From a1ff80e69d9a0f534e369a5ba1447c854ab40eaf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Mar 2026 14:14:40 +0000 Subject: [PATCH 1/2] [MINOR] Use SparkContext.isDriver() in GlobalSingletonManualClock --- .../sql/streaming/util/GlobalSingletonManualClock.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/GlobalSingletonManualClock.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/GlobalSingletonManualClock.scala index f7af0621e5935..84f24c470848a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/GlobalSingletonManualClock.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/GlobalSingletonManualClock.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.streaming.util -import org.apache.spark.SparkContext.DRIVER_IDENTIFIER -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.util.RpcUtils @@ -73,8 +72,7 @@ class GlobalManualClock(endpointName: String) private def isDriver: Boolean = { val executorId = SparkEnv.get.executorId - // Check for null to match the behavior of executorId == DRIVER_IDENTIFIER - executorId != null && executorId.startsWith(DRIVER_IDENTIFIER) + SparkContext.isDriver(executorId) } override def getTimeMillis(): Long = { From 06673cf1c023fb21915bdfb8d874d1091ba36942 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 21 Mar 2026 01:34:08 +0800 Subject: [PATCH 2/2] [MINOR] Add SparkContext.isDriver() and use it across the codebase Co-authored-by: Isaac --- .../org/apache/spark/profiler/SparkAsyncProfiler.scala | 5 ++--- core/src/main/scala/org/apache/spark/SparkContext.scala | 9 +++++++-- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++-- .../scala/org/apache/spark/rpc/netty/MessageLoop.scala | 2 +- .../scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- .../apache/spark/scheduler/EventLoggingListener.scala | 2 +- .../org/apache/spark/shuffle/ShuffleBlockPusher.scala | 2 +- .../org/apache/spark/status/AppStatusListener.scala | 6 +++--- .../scala/org/apache/spark/status/AppStatusStore.scala | 2 +- .../spark/status/api/v1/OneApplicationResource.scala | 2 +- .../scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerId.scala | 2 +- .../scala/org/apache/spark/HeartbeatReceiverSuite.scala | 2 +- .../test/scala/org/apache/spark/SparkContextSuite.scala | 2 +- .../spark/scheduler/EventLoggingListenerSuite.scala | 4 ++-- .../BlockManagerDecommissionIntegrationSuite.scala | 2 +- .../spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala | 2 +- .../spark/sql/execution/streaming/state/StateStore.scala | 5 ++--- .../spark/streaming/scheduler/ReceiverTracker.scala | 2 +- 19 files changed, 31 insertions(+), 28 deletions(-) diff --git a/connector/profiler/src/main/scala/org/apache/spark/profiler/SparkAsyncProfiler.scala b/connector/profiler/src/main/scala/org/apache/spark/profiler/SparkAsyncProfiler.scala index 02d39a2c435b3..f86b47e98d8fb 100644 --- a/connector/profiler/src/main/scala/org/apache/spark/profiler/SparkAsyncProfiler.scala +++ b/connector/profiler/src/main/scala/org/apache/spark/profiler/SparkAsyncProfiler.scala @@ -23,8 +23,7 @@ import one.profiler.{AsyncProfiler, AsyncProfilerLoader} import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} import org.apache.hadoop.fs.permission.FsPermission -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext.DRIVER_IDENTIFIER +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.PATH @@ -45,7 +44,7 @@ private[spark] class SparkAsyncProfiler(conf: SparkConf, executorId: String) ext private def getAppId: Option[String] = conf.getOption("spark.app.id") private def getAttemptId: Option[String] = conf.getOption("spark.app.attempt.id") - private val profileFile = if (executorId == DRIVER_IDENTIFIER) { + private val profileFile = if (SparkContext.isDriver(executorId)) { s"profile-$executorId.jfr" } else { s"profile-exec-$executorId.jfr" diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d6f8d40aa51b5..fad9bb522ad92 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -754,7 +754,7 @@ class SparkContext(config: SparkConf) extends Logging { */ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { try { - if (executorId == SparkContext.DRIVER_IDENTIFIER) { + if (SparkContext.isDriver(executorId)) { Some(Utils.getThreadDump()) } else { env.blockManager.master.getExecutorEndpointRef(executorId) match { @@ -786,7 +786,7 @@ class SparkContext(config: SparkConf) extends Logging { */ private[spark] def getExecutorHeapHistogram(executorId: String): Option[Array[String]] = { try { - if (executorId == SparkContext.DRIVER_IDENTIFIER) { + if (SparkContext.isDriver(executorId)) { Some(Utils.getHeapHistogram()) } else { env.blockManager.master.getExecutorEndpointRef(executorId) match { @@ -3163,6 +3163,11 @@ object SparkContext extends Logging { /** Separator of tags in SPARK_JOB_TAGS property */ private[spark] val SPARK_JOB_TAGS_SEP = "," + /** Returns true if the given executor ID identifies the driver. */ + private[spark] def isDriver(executorId: String): Boolean = { + DRIVER_IDENTIFIER == executorId + } + // Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag private[spark] def throwIfInvalidTag(tag: String) = { if (tag == null) { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5374a8f5a015c..7dcf66a609577 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -248,7 +248,7 @@ class SparkEnv ( Preconditions.checkState(null == _shuffleManager, "Shuffle manager already initialized to %s", _shuffleManager) try { - _shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER) + _shuffleManager = ShuffleManager.create(conf, SparkContext.isDriver(executorId)) } finally { // Signal that the ShuffleManager has been initialized shuffleManagerInitLatch.countDown() @@ -356,7 +356,7 @@ object SparkEnv extends Logging { listenerBus: LiveListenerBus = null, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + val isDriver = SparkContext.isDriver(executorId) // Listener bus is only used on the driver if (isDriver) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala index cce455270df43..0564c55e38b76 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala @@ -116,7 +116,7 @@ private class SharedMessageLoop( .getOrElse(math.max(2, availableCores)) conf.get(EXECUTOR_ID).map { id => - val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + val role = if (SparkContext.isDriver(id)) "driver" else "executor" conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) }.getOrElse(modNumThreads) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a4f6d5438bf3d..eb956aa1582dd 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -50,7 +50,7 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager, numUsableCores: Int) extends RpcEnv(conf) with Logging { val role = conf.get(EXECUTOR_ID).map { id => - if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + if (SparkContext.isDriver(id)) "driver" else "executor" } private[netty] val transportConf = SparkTransportConf.fromSparkConf( diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 1e46142fab255..97391e093d92e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -249,7 +249,7 @@ private[spark] class EventLoggingListener( override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { if (shouldLogStageExecutorMetrics) { - if (event.execId == SparkContext.DRIVER_IDENTIFIER) { + if (SparkContext.isDriver(event.execId)) { logEvent(event) } event.executorUpdates.foreach { case (stageKey1, newPeaks) => diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala index 548ecb399d5ff..8e44f4c2ebab0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala @@ -502,7 +502,7 @@ private[spark] object ShuffleBlockPusher { private val BLOCK_PUSHER_POOL: ExecutorService = { val conf = SparkEnv.get.conf if (Utils.isPushBasedShuffleEnabled(conf, - isDriver = SparkContext.DRIVER_IDENTIFIER == SparkEnv.get.executorId)) { + isDriver = SparkContext.isDriver(SparkEnv.get.executorId))) { val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS) .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1)) ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread") diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 52856427cb37a..219dd49df5d27 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -362,11 +362,11 @@ private[spark] class AppStatusListener( // Implicitly exclude every available executor for the stage associated with this node Option(liveStages.get((stageId, stageAttemptId))).foreach { stage => val executorIds = liveExecutors.values.filter(exec => exec.host == hostId - && exec.executorId != SparkContext.DRIVER_IDENTIFIER).map(_.executorId).toSeq + && !SparkContext.isDriver(exec.executorId)).map(_.executorId).toSeq setStageExcludedStatus(stage, now, executorIds: _*) } liveExecutors.values.filter(exec => exec.hostname == hostId - && exec.executorId != SparkContext.DRIVER_IDENTIFIER).foreach { exec => + && !SparkContext.isDriver(exec.executorId)).foreach { exec => addExcludedStageTo(exec, stageId, now) } } @@ -413,7 +413,7 @@ private[spark] class AppStatusListener( // Implicitly (un)exclude every executor associated with the node. liveExecutors.values.foreach { exec => - if (exec.hostname == host && exec.executorId != SparkContext.DRIVER_IDENTIFIER) { + if (exec.hostname == host && !SparkContext.isDriver(exec.executorId)) { updateExecExclusionStatus(exec, excluded, now) } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index a7f3fde9e6f6f..ccdfc6c48319a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -103,7 +103,7 @@ private[spark] class AppStatusStore( } private def replaceExec(origin: v1.ExecutorSummary): v1.ExecutorSummary = { - if (origin.id == SparkContext.DRIVER_IDENTIFIER) { + if (SparkContext.isDriver(origin.id)) { replaceDriverGcTime(origin, extractGcTime(origin), extractAppTime) } else { origin diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 604a5017db3b5..de25e7c524ead 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -182,7 +182,7 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { } private def checkExecutorId(execId: String): Unit = { - if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + if (!SparkContext.isDriver(execId) && !execId.forall(Character.isDigit)) { throw new BadParameterException( s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 11d92e9982142..5fbc8dca74f68 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -213,7 +213,7 @@ private[spark] class BlockManager( // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)` private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined - private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + private val isDriver = SparkContext.isDriver(executorId) private val remoteReadNioBufferConversion = conf.get(Network.NETWORK_REMOTE_READ_NIO_BUFFER_CONVERSION) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 5f09d845249d6..d41890919c332 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -65,7 +65,7 @@ class BlockManagerId private ( def topologyInfo: Option[String] = topologyInfo_ def isDriver: Boolean = { - executorId == SparkContext.DRIVER_IDENTIFIER + SparkContext.isDriver(executorId) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 817dbab2c22b0..7d1b8f8ea508b 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -305,7 +305,7 @@ class HeartbeatReceiverSuite // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). - filter { case (k, _) => k != SparkContext.DRIVER_IDENTIFIER } + filter { case (k, _) => !SparkContext.isDriver(k) } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 8dfb1ba67f1c1..307013a148d6a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -849,7 +849,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val listener = new SparkListener { override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { - if (executorMetricsUpdate.execId != SparkContext.DRIVER_IDENTIFIER) { + if (!SparkContext.isDriver(executorMetricsUpdate.execId)) { runningTaskIds = executorMetricsUpdate.accumUpdates.map(_._1) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 74511d642729e..01035df87a169 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -531,7 +531,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit events.foreach { event => event match { case metricsUpdate: SparkListenerExecutorMetricsUpdate - if metricsUpdate.execId != SparkContext.DRIVER_IDENTIFIER => + if !SparkContext.isDriver(metricsUpdate.execId) => case stageCompleted: SparkListenerStageCompleted => val execIds = Set[String]() (1 to 3).foreach { _ => @@ -631,7 +631,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit case (expected: SparkListenerExecutorMetricsUpdate, actual: SparkListenerExecutorMetricsUpdate) => assert(expected.execId == actual.execId) - assert(expected.execId == SparkContext.DRIVER_IDENTIFIER) + assert(SparkContext.isDriver(expected.execId)) case (expected: SparkListenerEvent, actual: SparkListenerEvent) => assert(expected === actual) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala index 8c6b9cc288ec1..4fd25c3d5ad91 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala @@ -207,7 +207,7 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { val executorId = executorMetricsUpdate.execId - if (executorId != SparkContext.DRIVER_IDENTIFIER) { + if (!SparkContext.isDriver(executorId)) { val validUpdate = executorMetricsUpdate .accumUpdates .flatMap(_._4) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala index c624050d819ff..9c318171f9cbe 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorRollPlugin.scala @@ -111,7 +111,7 @@ class ExecutorRollDriverPlugin extends DriverPlugin with Logging { private def choose(list: Seq[v1.ExecutorSummary], policy: ExecutorRollPolicy.Value) : Option[String] = { val listWithoutDriver = list - .filterNot(_.id.equals(SparkContext.DRIVER_IDENTIFIER)) + .filterNot(e => SparkContext.isDriver(e.id)) .filter(_.totalTasks >= minTasks) val sortedList = policy match { case ExecutorRollPolicy.ID => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ec0b8733ec67f..6e08c10476ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -988,7 +988,7 @@ object StateStoreProvider extends Logging { private[state] def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized { val env = SparkEnv.get if (env != null) { - val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER + val isDriver = SparkContext.isDriver(env.executorId) // If running locally, then the coordinator reference in stateStoreCoordinatorRef may have // become inactive as SparkContext + SparkEnv may have been restarted. Hence, when running in // driver, always recreate the reference. @@ -1765,8 +1765,7 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - val isDriver = - env.executorId == SparkContext.DRIVER_IDENTIFIER + val isDriver = SparkContext.isDriver(env.executorId) // If running locally, then the coordinator reference in _coordRef may be have become inactive // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index c2ca04c7be931..64717767b8453 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -414,7 +414,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false Seq(ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId)) } else { ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => - blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + !SparkContext.isDriver(blockManagerId.executorId) // Ignore the driver location }.map { case (blockManagerId, _) => ExecutorCacheTaskLocation(blockManagerId.host, blockManagerId.executorId) }.toSeq