@@ -84,6 +84,16 @@ private[spark] class Executor(
// Thread pool on which task launch worker threads run
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
private val executorSource = new ExecutorSource(threadPool, executorId)
// Dedicated pool for the threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
// Maps a task id to the most recently created TaskReaper for a task that is in the process of
// being killed. Its purpose is to avoid creating a separate TaskReaper for every killTask() call
// on the same task: callers consult the map to decide whether an existing reaper already
// fulfills the role of the one they would otherwise create.
//
// Every access must be synchronized on the map itself. A ConcurrentHashMap is deliberately not
// used, because the lock protects more than the map's internal state (the check-then-create
// sequence must be atomic as a whole).
private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap.empty[Long, TaskReaper]
8797
8898 if (! isLocal) {
8999 env.metricsSystem.registerSource(executorSource)
@@ -93,6 +103,9 @@ private[spark] class Executor(
// When true, classes in user jars are loaded before those in Spark jars
private val userClassPathFirst: Boolean =
  conf.getBoolean("spark.executor.userClassPathFirst", false)

// When true, killed / interrupted tasks are monitored by a TaskReaper (off by default)
private val taskReaperEnabled: Boolean =
  conf.getBoolean("spark.task.reaper.enabled", false)

// Create our ClassLoader.
// This must happen after SparkEnv creation so that the SecurityManager can be accessed.
private val urlClassLoader = createClassLoader()
@@ -148,9 +161,27 @@ private[spark] class Executor(
148161 }
149162
/**
 * Kills the running task with the given id, if there is one.
 *
 * When the task reaper is disabled the task is killed directly. When it is enabled, the kill is
 * delegated to a TaskReaper running on the reaper pool; a new reaper is only created if the task
 * has no reaper yet, or if this call asks for a thread interrupt and the existing reaper was
 * created without one.
 *
 * @param taskId id of the task to kill; unknown ids are ignored
 * @param interruptThread whether to interrupt the task thread
 */
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
  val runner = runningTasks.get(taskId)
  if (runner != null) {
    if (!taskReaperEnabled) {
      runner.kill(interruptThread = interruptThread)
    } else {
      // Decide and register under the map's lock so that concurrent killTask() calls for the
      // same task cannot both create a reaper.
      val reaperToStart: Option[TaskReaper] = taskReaperForTask.synchronized {
        // A reaper is needed when none exists, or when this call escalates from a
        // non-interrupting kill to an interrupting one.
        val needsNewReaper = taskReaperForTask.get(taskId).forall { current =>
          interruptThread && !current.interruptThread
        }
        if (needsNewReaper) {
          val reaper = new TaskReaper(runner, interruptThread = interruptThread)
          taskReaperForTask(taskId) = reaper
          Some(reaper)
        } else {
          None
        }
      }
      // Submit the TaskReaper only after the lock has been released.
      reaperToStart.foreach(reaper => taskReaperPool.execute(reaper))
    }
  }
}
156187
@@ -161,12 +192,7 @@ private[spark] class Executor(
161192 * @param interruptThread whether to interrupt the task thread
162193 */
def killAllTasks(interruptThread: Boolean): Unit = {
  // Route every running task through killTask() so that each one gets the same handling
  // (including the optional TaskReaper supervision) as an individually killed task.
  for (taskId <- runningTasks.keys().asScala) {
    killTask(taskId, interruptThread = interruptThread)
  }
}
171197
172198 def stop (): Unit = {
@@ -192,13 +218,21 @@ private[spark] class Executor(
192218 serializedTask : ByteBuffer )
193219 extends Runnable {
194220
221+ val threadName = s " Executor task launch worker for task $taskId"
222+
195223 /** Whether this task has been killed. */
196224 @ volatile private var killed = false
197225
226+ @ volatile private var threadId : Long = - 1
227+
228+ def getThreadId : Long = threadId
229+
198230 /** Whether this task has been finished. */
199231 @ GuardedBy (" TaskRunner.this" )
200232 private var finished = false
201233
234+ def isFinished : Boolean = synchronized { finished }
235+
202236 /** How much the JVM process has spent in GC when the task starts to run. */
203237 @ volatile var startGCTime : Long = _
204238
@@ -229,9 +263,15 @@ private[spark] class Executor(
229263 // ClosedByInterruptException during execBackend.statusUpdate which causes
230264 // Executor to crash
231265 Thread .interrupted()
266+ // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
267+ // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
268+ // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
269+ notifyAll()
232270 }
233271
234272 override def run (): Unit = {
273+ threadId = Thread .currentThread.getId
274+ Thread .currentThread.setName(threadName)
235275 val taskMemoryManager = new TaskMemoryManager (env.memoryManager, taskId)
236276 val deserializeStartTime = System .currentTimeMillis()
237277 Thread .currentThread.setContextClassLoader(replClassLoader)
@@ -411,6 +451,117 @@ private[spark] class Executor(
411451 }
412452 }
413453
/**
 * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
 * sending a Thread.interrupt(), and monitoring the task until it finishes.
 *
 * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
 * may not be interruptible or may not respond to their "killed" flags being set. If a significant
 * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
 * remain running then this can lead to a situation where new jobs and tasks are starved of
 * resources that are being used by these zombie tasks.
 *
 * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
 * tasks. For backwards-compatibility / backportability this component is disabled by default
 * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
 *
 * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
 * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
 * in case kill is called twice with different values for the `interrupt` parameter.
 *
 * Once created, a TaskReaper will run until its supervised task has finished running. If the
 * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
 * `spark.task.reaper.killTimeout <= 0`) then this implies that the TaskReaper may run
 * indefinitely if the supervised task never exits.
 *
 * @param taskRunner the runner of the task being killed and monitored
 * @param interruptThread whether the kill should also interrupt the task thread
 */
private class TaskReaper(
    taskRunner: TaskRunner,
    val interruptThread: Boolean)
  extends Runnable {

  private[this] val taskId: Long = taskRunner.taskId

  // How long to wait on the TaskRunner between checks of whether the task has finished.
  private[this] val killPollingIntervalMs: Long =
    conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")

  // Time after which a still-running killed task is treated as unstoppable. Only values > 0
  // enable the timeout.
  private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")

  // Whether to log a thread dump of the task thread on every polling round.
  private[this] val takeThreadDump: Boolean =
    conf.getBoolean("spark.task.reaper.threadDump", true)

  override def run(): Unit = {
    // Measure elapsed time with the monotonic clock: System.currentTimeMillis() follows the
    // wall clock, so an NTP or manual clock adjustment could trip the kill timeout early (and
    // take down the executor JVM) or postpone it indefinitely.
    val startTimeNs = System.nanoTime()
    def elapsedTimeMs: Long = (System.nanoTime() - startTimeNs) / 1000000L
    def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
    try {
      // Only attempt to kill the task once. If interruptThread = false then a second kill
      // attempt would be a no-op and if interruptThread = true then it may not be safe or
      // effective to interrupt multiple times:
      taskRunner.kill(interruptThread = interruptThread)
      // Monitor the killed task until it exits. The synchronization logic here is complicated
      // because we don't want to synchronize on the taskRunner while possibly taking a thread
      // dump, but we also need to be careful to avoid races between checking whether the task
      // has finished and wait()ing for it to finish.
      var finished: Boolean = false
      while (!finished && !timeoutExceeded()) {
        taskRunner.synchronized {
          // We need to synchronize on the TaskRunner while checking whether the task has
          // finished in order to avoid a race where the task is marked as finished right after
          // we check and before we call wait().
          if (taskRunner.isFinished) {
            finished = true
          } else {
            taskRunner.wait(killPollingIntervalMs)
          }
        }
        // Re-check outside the lock: the wait may have ended because the task finished, because
        // the polling interval elapsed, or spuriously.
        if (taskRunner.isFinished) {
          finished = true
        } else {
          logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
          if (takeThreadDump) {
            try {
              Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
                // Only log the dump if the thread still carries this task's name.
                if (thread.threadName == taskRunner.threadName) {
                  logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
                }
              }
            } catch {
              // A failed thread dump must not stop the monitoring loop.
              case NonFatal(e) =>
                logWarning("Exception thrown while obtaining thread dump: ", e)
            }
          }
        }
      }

      if (!taskRunner.isFinished && timeoutExceeded()) {
        if (isLocal) {
          logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
            "not killing JVM because we are running in local mode.")
        } else {
          // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
          // handler and cause the executor JVM to exit.
          throw new SparkException(
            s"Killing executor JVM because killed task $taskId could not be stopped within " +
              s"$killTimeoutMs ms.")
        }
      }
    } finally {
      // Clean up entries in the taskReaperForTask map.
      taskReaperForTask.synchronized {
        // Remove the entry only if it still points at this reaper. If it points at another one,
        // this must have been a TaskReaper with interruptThread == false whose map entry was
        // overwritten by a later killTask() call with interruptThread == true; that entry
        // belongs to the newer reaper and is left alone.
        if (taskReaperForTask.get(taskId).exists(_ eq this)) {
          taskReaperForTask.remove(taskId)
        }
      }
    }
  }
}
564+
414565 /**
415566 * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
416567 * created by the interpreter to the search path
0 commit comments