Commit 678d91c

JoshRosen authored and yhuai committed
[SPARK-18761][BRANCH-2.0] Introduce "task reaper" to oversee task killing in executors
Branch-2.0 backport of #16189; original description follows:

## What changes were proposed in this pull request?

Spark's current task cancellation / task killing mechanism is "best effort" because some tasks may not be interruptible or may not respond to their "killed" flags being set. If a significant fraction of a cluster's task slots are occupied by tasks that have been marked as killed but remain running then this can lead to a situation where new jobs and tasks are starved of resources that are being used by these zombie tasks.

This patch aims to address this problem by adding a "task reaper" mechanism to executors. At a high-level, task killing now launches a new thread which attempts to kill the task and then watches the task and periodically checks whether it has been killed. The TaskReaper will periodically re-attempt to call `TaskRunner.kill()` and will log warnings if the task keeps running. I modified TaskRunner to rename its thread at the start of the task, allowing TaskReaper to take a thread dump and filter it in order to log stacktraces from the exact task thread that we are waiting to finish. If the task has not stopped after a configurable timeout then the TaskReaper will throw an exception to trigger executor JVM death, thereby forcibly freeing any resources consumed by the zombie tasks.

This feature is flagged off by default and is controlled by four new configurations under the `spark.task.reaper.*` namespace. See the updated `configuration.md` doc for details.

## How was this patch tested?

Tested via a new test case in `JobCancellationSuite`, plus manual testing.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #16358 from JoshRosen/cancellation-branch-2.0.
1 parent 1f0c5fa commit 678d91c
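
The kill-then-watch behaviour described in the commit message can be sketched outside Spark with plain JVM threads. This is a minimal illustration under stated assumptions, not the code from this commit: `ToyTask`, `honorsKill`, `release` and `reap` are hypothetical names, the task "kill" is only a volatile flag, and the sketch returns `false` on timeout where the real reaper throws to bring down the executor JVM.

```scala
object ReaperSketch {
  /** Hypothetical stand-in for TaskRunner; not Spark's class. */
  final class ToyTask(honorsKill: Boolean) extends Runnable {
    @volatile private var killed = false
    @volatile private var released = false // escape hatch so the demo thread can always exit
    private var finished = false // guarded by `this`

    def kill(): Unit = { killed = true }
    def release(): Unit = { released = true }
    def isFinished: Boolean = synchronized { finished }

    override def run(): Unit = {
      // A cooperative task stops once `killed` is set; a "zombie" ignores the flag.
      while (!released && !(honorsKill && killed)) Thread.sleep(5)
      synchronized {
        finished = true
        notifyAll() // wake any reaper blocked in wait()
      }
    }
  }

  /** Kills `task` once, then polls until it finishes or `timeoutMs` elapses. Returns true if it finished. */
  def reap(task: ToyTask, pollMs: Long, timeoutMs: Long): Boolean = {
    val startMs = System.currentTimeMillis()
    def timeoutExceeded: Boolean =
      timeoutMs > 0 && System.currentTimeMillis() - startMs > timeoutMs
    task.kill()
    var finished = false
    while (!finished && !timeoutExceeded) {
      task.synchronized {
        // Check and wait under the same lock so a finish between the two cannot be missed.
        if (task.isFinished) finished = true else task.wait(pollMs)
      }
    }
    task.isFinished
  }
}
```

With `honorsKill = true` the call to `reap` returns `true` almost immediately; with `honorsKill = false` it returns `false` once the timeout passes, which is the situation in which the patch escalates to killing the JVM.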

4 files changed: +300 -14 lines changed

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 160 additions & 9 deletions
@@ -84,6 +84,16 @@ private[spark] class Executor(
   // Start worker thread pool
   private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
   private val executorSource = new ExecutorSource(threadPool, executorId)
+  // Pool used for threads that supervise task killing / cancellation
+  private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
+  // For tasks which are in the process of being killed, this map holds the most recently created
+  // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
+  // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
+  // the integrity of the map's internal state). The purpose of this map is to prevent the creation
+  // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
+  // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
+  // create. The map key is a task id.
+  private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()

   if (!isLocal) {
     env.metricsSystem.registerSource(executorSource)
@@ -93,6 +103,9 @@ private[spark] class Executor(
   // Whether to load classes in user jars before those in Spark jars
   private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false)

+  // Whether to monitor killed / interrupted tasks
+  private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -148,9 +161,27 @@ private[spark] class Executor(
   }

   def killTask(taskId: Long, interruptThread: Boolean): Unit = {
-    val tr = runningTasks.get(taskId)
-    if (tr != null) {
-      tr.kill(interruptThread)
+    val taskRunner = runningTasks.get(taskId)
+    if (taskRunner != null) {
+      if (taskReaperEnabled) {
+        val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
+          val shouldCreateReaper = taskReaperForTask.get(taskId) match {
+            case None => true
+            case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
+          }
+          if (shouldCreateReaper) {
+            val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+            taskReaperForTask(taskId) = taskReaper
+            Some(taskReaper)
+          } else {
+            None
+          }
+        }
+        // Execute the TaskReaper from outside of the synchronized block.
+        maybeNewTaskReaper.foreach(taskReaperPool.execute)
+      } else {
+        taskRunner.kill(interruptThread = interruptThread)
+      }
     }
   }

@@ -161,12 +192,7 @@ private[spark] class Executor(
    * @param interruptThread whether to interrupt the task thread
    */
   def killAllTasks(interruptThread: Boolean) : Unit = {
-    // kill all the running tasks
-    for (taskRunner <- runningTasks.values().asScala) {
-      if (taskRunner != null) {
-        taskRunner.kill(interruptThread)
-      }
-    }
+    runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
   }

   def stop(): Unit = {
@@ -192,13 +218,21 @@ private[spark] class Executor(
       serializedTask: ByteBuffer)
     extends Runnable {

+    val threadName = s"Executor task launch worker for task $taskId"
+
     /** Whether this task has been killed. */
     @volatile private var killed = false

+    @volatile private var threadId: Long = -1
+
+    def getThreadId: Long = threadId
+
     /** Whether this task has been finished. */
     @GuardedBy("TaskRunner.this")
     private var finished = false

+    def isFinished: Boolean = synchronized { finished }
+
     /** How much the JVM process has spent in GC when the task starts to run. */
     @volatile var startGCTime: Long = _

@@ -229,9 +263,15 @@ private[spark] class Executor(
       // ClosedByInterruptException during execBackend.statusUpdate which causes
       // Executor to crash
       Thread.interrupted()
+      // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there
+      // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False)
+      // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup:
+      notifyAll()
     }

     override def run(): Unit = {
+      threadId = Thread.currentThread.getId
+      Thread.currentThread.setName(threadName)
       val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
@@ -411,6 +451,117 @@ private[spark] class Executor(
     }
   }

+  /**
+   * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally
+   * sending a Thread.interrupt(), and monitoring the task until it finishes.
+   *
+   * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
+   * may not be interruptable or may not respond to their "killed" flags being set. If a significant
+   * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
+   * remain running then this can lead to a situation where new jobs and tasks are starved of
+   * resources that are being used by these zombie tasks.
+   *
+   * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie
+   * tasks. For backwards-compatibility / backportability this component is disabled by default
+   * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`.
+   *
+   * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically
+   * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers
+   * in case kill is called twice with different values for the `interrupt` parameter.
+   *
+   * Once created, a TaskReaper will run until its supervised task has finished running. If the
+   * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if
+   * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely
+   * if the supervised task never exits.
+   */
+  private class TaskReaper(
+      taskRunner: TaskRunner,
+      val interruptThread: Boolean)
+    extends Runnable {
+
+    private[this] val taskId: Long = taskRunner.taskId
+
+    private[this] val killPollingIntervalMs: Long =
+      conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")
+
+    private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")
+
+    private[this] val takeThreadDump: Boolean =
+      conf.getBoolean("spark.task.reaper.threadDump", true)
+
+    override def run(): Unit = {
+      val startTimeMs = System.currentTimeMillis()
+      def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
+      def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs
+      try {
+        // Only attempt to kill the task once. If interruptThread = false then a second kill
+        // attempt would be a no-op and if interruptThread = true then it may not be safe or
+        // effective to interrupt multiple times:
+        taskRunner.kill(interruptThread = interruptThread)
+        // Monitor the killed task until it exits. The synchronization logic here is complicated
+        // because we don't want to synchronize on the taskRunner while possibly taking a thread
+        // dump, but we also need to be careful to avoid races between checking whether the task
+        // has finished and wait()ing for it to finish.
+        var finished: Boolean = false
+        while (!finished && !timeoutExceeded()) {
+          taskRunner.synchronized {
+            // We need to synchronize on the TaskRunner while checking whether the task has
+            // finished in order to avoid a race where the task is marked as finished right after
+            // we check and before we call wait().
+            if (taskRunner.isFinished) {
+              finished = true
+            } else {
+              taskRunner.wait(killPollingIntervalMs)
+            }
+          }
+          if (taskRunner.isFinished) {
+            finished = true
+          } else {
+            logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
+            if (takeThreadDump) {
+              try {
+                Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
+                  if (thread.threadName == taskRunner.threadName) {
+                    logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
+                  }
+                }
+              } catch {
+                case NonFatal(e) =>
+                  logWarning("Exception thrown while obtaining thread dump: ", e)
+              }
+            }
+          }
+        }
+
+        if (!taskRunner.isFinished && timeoutExceeded()) {
+          if (isLocal) {
+            logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
+              "not killing JVM because we are running in local mode.")
+          } else {
+            // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
+            // handler and cause the executor JVM to exit.
+            throw new SparkException(
+              s"Killing executor JVM because killed task $taskId could not be stopped within " +
+                s"$killTimeoutMs ms.")
+          }
+        }
+      } finally {
+        // Clean up entries in the taskReaperForTask map.
+        taskReaperForTask.synchronized {
+          taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
+            if (taskReaperInMap eq this) {
+              taskReaperForTask.remove(taskId)
+            } else {
+              // This must have been a TaskReaper where interruptThread == false where a subsequent
+              // killTask() call for the same task had interruptThread == true and overwrote the
+              // map entry.
+            }
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
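
The rule in `killTask()` for deciding whether a new reaper is needed can be isolated into a small sketch. The version below is an illustration under assumptions, not the executor code: `ReaperDedupSketch` and `register` are hypothetical names, and the map stores only the `interruptThread` flag of the most recent reaper rather than a `TaskReaper` instance.

```scala
import scala.collection.mutable

object ReaperDedupSketch {
  // Hypothetical registry: task id -> interruptThread flag of the most recently created reaper.
  private val reaperForTask = mutable.HashMap[Long, Boolean]()

  /** Returns true if a kill request with this flag should start a new reaper for the task. */
  def register(taskId: Long, interruptThread: Boolean): Boolean = reaperForTask.synchronized {
    val shouldCreate = reaperForTask.get(taskId) match {
      case None => true
      // Only an upgrade from a non-interrupting to an interrupting kill adds a second reaper.
      case Some(existingInterrupts) => interruptThread && !existingInterrupts
    }
    if (shouldCreate) reaperForTask(taskId) = interruptThread
    shouldCreate
  }
}
```

Calling `register(1L, false)`, then `register(1L, false)`, then `register(1L, true)` yields `true`, `false`, `true`: repeated identical kill requests are absorbed, and a task ends up with at most two reapers, which is why the task runner wakes waiters with `notifyAll()`.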

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 21 additions & 5 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.util

 import java.io._
-import java.lang.management.ManagementFactory
+import java.lang.management.{ManagementFactory, ThreadInfo}
 import java.net._
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
@@ -2112,13 +2112,29 @@ private[spark] object Utils extends Logging {
     // We need to filter out null values here because dumpAllThreads() may return null array
     // elements for threads that are dead / don't exist.
     val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
-    threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
-      val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
-      ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
-        threadInfo.getThreadState, stackTrace)
+    threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+  }
+
+  def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
+    if (threadId <= 0) {
+      None
+    } else {
+      // The Int.MaxValue here requests the entire untruncated stack trace of the thread:
+      val threadInfo =
+        Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
+      threadInfo.map(threadInfoToThreadStackTrace)
     }
   }

+  private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
+    val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
+    ThreadStackTrace(
+      threadId = threadInfo.getThreadId,
+      threadName = threadInfo.getThreadName,
+      threadState = threadInfo.getThreadState,
+      stackTrace = stackTrace)
+  }
+
   /**
    * Convert all spark properties set in the given SparkConf to a sequence of java options.
    */
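
The single-thread lookup added to `Utils` relies on the JDK's `ThreadMXBean.getThreadInfo(long, int)`, which returns `null` for a thread that is not alive and rejects non-positive ids. A standalone sketch of the same pattern, returning plain text instead of Spark's `ThreadStackTrace` (the object and method names here are hypothetical), might look like this:

```scala
import java.lang.management.ManagementFactory

object ThreadDumpSketch {
  /** Stack trace of one live thread as text, or None if the id is invalid or the thread is gone. */
  def stackTraceOf(threadId: Long): Option[String] = {
    if (threadId <= 0) {
      // getThreadInfo throws IllegalArgumentException for ids <= 0, so guard first.
      None
    } else {
      // Int.MaxValue asks for the full, untruncated stack; a null result becomes None.
      Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
        .map(_.getStackTrace.map(_.toString).mkString("\n"))
    }
  }
}
```

Asking for one thread by id avoids the cost of `dumpAllThreads` each time the reaper polls, and the `threadId <= 0` guard covers the `-1` placeholder a task runner holds before its thread has started.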

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 77 additions & 0 deletions
@@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
     assert(jobB.get() === 100)
   }

+  test("task reaper kills JVM if killed tasks keep running for too long") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "5s")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
+      sc.parallelize(1 to 10000, 2).map { i =>
+        while (true) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
+  test("task reaper will not kill JVM if spark.task.killTimeout == -1") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "-1")
+      .set("spark.task.reaper.PollingInterval", "1s")
+      .set("spark.deploy.maxExecutorRetries", "1")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)
+      sc.parallelize(1 to 2, 2).map { i =>
+        val startTime = System.currentTimeMillis()
+        while (System.currentTimeMillis() < startTime + 10000) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
   test("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // twoJobsSharingStageSemaphore: