5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -103,6 +103,11 @@ abstract class TaskContext extends Serializable {
* This will be called in all situations - success, failure, or cancellation. Adding a listener
* to an already completed task will result in that listener being called immediately.
*
* Two listeners registered in the same thread will be invoked in reverse order of registration if
* the task completes after both are registered. There are no ordering guarantees for listeners
* registered in different threads, or for listeners registered after the task completes.
* Listeners are guaranteed to execute sequentially.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
144 changes: 102 additions & 42 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,7 +17,7 @@

package org.apache.spark

import java.util.Properties
import java.util.{Properties, Stack}
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
@@ -39,9 +39,9 @@ import org.apache.spark.util._
* A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
* sure that updates are always visible across threads. The complete & failed flags and their
* callbacks are protected by locking on the context instance. For instance, this ensures
* that you cannot add a completion listener in one thread while we are completing (and calling
* the completion listeners) in another thread. Other state is immutable, however the exposed
* `TaskMetrics` & `MetricsSystem` objects are not thread safe.
* that you cannot add a completion listener in one thread while we are completing in another
* thread. Other state is immutable, however the exposed `TaskMetrics` & `MetricsSystem` objects are
* not thread safe.
*/
private[spark] class TaskContextImpl(
override val stageId: Int,
@@ -59,81 +59,141 @@ private[spark] class TaskContextImpl(
extends TaskContext
with Logging {

/** List of callback functions to execute when the task completes. */
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
/**
* List of callback functions to execute when the task completes.
*
* Using a stack causes us to process listeners in reverse order of registration. As listeners are
* invoked, they are popped from the stack.
*/
@transient private val onCompleteCallbacks = new Stack[TaskCompletionListener]

/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
@transient private val onFailureCallbacks = new Stack[TaskFailureListener]

/**
* The thread currently executing task completion or failure listeners, if any.
*
* `invokeListeners()` uses this to ensure listeners are called sequentially.
*/
@transient private var listenerInvocationThread: Option[Thread] = None

// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None

// Whether the task has completed.
private var completed: Boolean = false

// Whether the task has failed.
private var failed: Boolean = false

// Throwable that caused the task to fail
private var failure: Throwable = _
// If defined, the task has failed and this option contains the Throwable that caused the task to
// fail.
private var failureCauseOpt: Option[Throwable] = None

// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None

@GuardedBy("this")
override def addTaskCompletionListener(listener: TaskCompletionListener)
: this.type = synchronized {
if (completed) {
listener.onTaskCompletion(this)
} else {
onCompleteCallbacks += listener
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
val needToCallListener = synchronized {
// If there is already a thread invoking listeners, adding the new listener to
// `onCompleteCallbacks` will cause that thread to execute the new listener, and the call to
// `invokeTaskCompletionListeners()` below will be a no-op.
//
// If there is no such thread, the call to `invokeTaskCompletionListeners()` below will
// execute all listeners, including the new listener.
onCompleteCallbacks.push(listener)
completed
}
Contributor

@timarmstrong timarmstrong Oct 12, 2021

The API doesn't intend to guarantee any ordering of when the task completion listeners are called AFAICT. I think before this change the implementation ends up guaranteeing that the listeners are called sequentially. So it seems possible that some code could be accidentally depending on that.

This might be overengineering it, but we could have a scheme that avoided the deadlock issues and guaranteed sequential execution of callbacks. You would have at most one single thread at any point in time responsible for invoking callbacks. If another thread needs to invoke a callback, it either delegates it to the current callback invocation thread, or it becomes the callback execution thread itself. This means that the callback invocation thread needs to first invoke all of the current registered callbacks, but when it's done with those, check to see if any more callbacks have been queued.

I think we could do that by having the callback invocation thread take ownership of the current callbacks list and, after invoking those callbacks, check whether any more have been queued. We'd also need a variable to track whether there's a current callback execution thread.
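
A minimal sketch of the scheme described in this comment, under stated assumptions: `SequentialInvoker`, `submit`, `next`, and `drain` are hypothetical names used only for illustration and are not part of Spark. It uses a LIFO queue, runs callbacks outside the lock, and swallows callback exceptions to stay short, which a real implementation would probably not do.

```scala
import java.util.ArrayDeque
import scala.util.control.NonFatal

// Hypothetical illustration only; this class is not part of Spark.
final class SequentialInvoker {
  // Pending callbacks, most recently queued first. Guarded by `this`.
  private val pending = new ArrayDeque[() => Unit]()
  // The thread currently running callbacks, if any. Guarded by `this`.
  private var invoker: Option[Thread] = None

  /** Queues `callback`, then runs queued callbacks unless another thread already is. */
  def submit(callback: () => Unit): Unit = {
    val becameInvoker = synchronized {
      pending.push(callback)
      if (invoker.isEmpty) {
        invoker = Some(Thread.currentThread())
        true
      } else {
        false // Delegate to the thread that is already invoking callbacks.
      }
    }
    if (becameInvoker) drain()
  }

  // Returns the next callback, or deregisters the invoker thread if none remain.
  private def next(): Option[() => Unit] = synchronized {
    if (pending.isEmpty) {
      invoker = None
      None
    } else {
      Some(pending.pop())
    }
  }

  // Runs callbacks one at a time without holding the lock, so a callback may
  // call `submit` again (from this or another thread) without deadlocking.
  private def drain(): Unit = {
    var cb = next()
    while (cb.nonEmpty) {
      try cb.get.apply()
      catch { case NonFatal(_) => () } // Swallowed here only to keep the sketch short.
      cb = next()
    }
  }
}
```

Because the emptiness check and the deregistration happen under the same lock as `submit`'s enqueue, a callback queued while another is running is either picked up by the current invoker or run by the submitting thread, never dropped.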

Contributor Author

Hmm, good point that we'd be changing the behavior of this API. It would be nice to preserve the sequential execution behavior, but it does seem pretty complex. I can try implementing it and see whether it's worth it.

Either way, we should probably document and test the behavior more thoroughly. In the current state of the PR, I think the guarantee is something like the following: "Two listeners registered in the same thread will be invoked in reverse order of registration if the task finishes after both are registered. There are no ordering guarantees for listeners registered in different threads, and they may execute concurrently."
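
To illustrate only the same-thread part of that guarantee, here is a small sketch assuming listeners are kept in a `java.util.Stack` and popped one at a time; `ReverseOrderSketch` and `drain` are hypothetical names, and plain strings stand in for listeners.

```scala
import java.util.Stack
import scala.collection.mutable.ArrayBuffer

// Hypothetical illustration only; strings stand in for registered listeners.
object ReverseOrderSketch {
  /** Pops every element off `stack`, returning them in the order they were popped. */
  def drain(stack: Stack[String]): Seq[String] = {
    val popped = ArrayBuffer.empty[String]
    while (!stack.empty()) {
      popped += stack.pop()
    }
    popped.toSeq
  }
}
```

Pushing "A", "B", "C" and then draining yields "C", "B", "A", which is the reverse of the registration order.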

Contributor

> There are no ordering guarantees for listeners registered in different threads

I agree. When there are multiple threads I don't think we can define an "order".

Contributor Author

@timarmstrong I implemented your suggestion to ensure sequential execution of listeners - it wasn't too complex after all. I also added tests to verify sequential execution, ordering, and liveness in case of reentrancy.

if (needToCallListener) {
invokeTaskCompletionListeners(None)
}
this
}

@GuardedBy("this")
override def addTaskFailureListener(listener: TaskFailureListener)
: this.type = synchronized {
if (failed) {
listener.onTaskFailure(this, failure)
} else {
onFailureCallbacks += listener
}
override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
synchronized {
onFailureCallbacks.push(listener)
failureCauseOpt
}.foreach(invokeTaskFailureListeners)
this
}

override def resourcesJMap(): java.util.Map[String, ResourceInformation] = {
resources.asJava
}

@GuardedBy("this")
private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized {
if (failed) return
failed = true
failure = error
invokeListeners(onFailureCallbacks.toSeq, "TaskFailureListener", Option(error)) {
_.onTaskFailure(this, error)
private[spark] override def markTaskFailed(error: Throwable): Unit = {
synchronized {
if (failureCauseOpt.isDefined) return
failureCauseOpt = Some(error)
}
invokeTaskFailureListeners(error)
}

@GuardedBy("this")
private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
if (completed) return
completed = true
invokeListeners(onCompleteCallbacks.toSeq, "TaskCompletionListener", error) {
private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = {
synchronized {
if (completed) return
completed = true
}
invokeTaskCompletionListeners(error)
}

private def invokeTaskCompletionListeners(error: Option[Throwable]): Unit = {
// It is safe to access the reference to `onCompleteCallbacks` without holding the TaskContext
// lock. `invokeListeners()` acquires the lock before accessing the contents.
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
_.onTaskCompletion(this)
}
}

private def invokeTaskFailureListeners(error: Throwable): Unit = {
// It is safe to access the reference to `onFailureCallbacks` without holding the TaskContext
// lock. `invokeListeners()` acquires the lock before accessing the contents.
invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
_.onTaskFailure(this, error)
}
}

private def invokeListeners[T](
listeners: Seq[T],
listeners: Stack[T],
name: String,
error: Option[Throwable])(
callback: T => Unit): Unit = {
// This method is subject to two constraints:
//
// 1. Listeners must be run sequentially to uphold the guarantee provided by the TaskContext
// API.
//
// 2. Listeners may spawn threads that call methods on this TaskContext. To avoid deadlock, we
// cannot call listeners while holding the TaskContext lock.
//
// We meet these constraints by ensuring there is at most one thread invoking listeners at any
// point in time.
synchronized {
if (listenerInvocationThread.nonEmpty) {
// If another thread is already invoking listeners, do nothing.
return
} else {
// If no other thread is invoking listeners, register this thread as the listener invocation
// thread. This prevents other threads from invoking listeners until this thread is
// deregistered.
listenerInvocationThread = Some(Thread.currentThread())
}
}

def getNextListenerOrDeregisterThread(): Option[T] = synchronized {
if (listeners.empty()) {
// We have executed all listeners that have been added so far. Deregister this thread as the
// callback invocation thread.
listenerInvocationThread = None
None
} else {
Some(listeners.pop())
}
}

val errorMsgs = new ArrayBuffer[String](2)
// Process callbacks in the reverse order of registration
listeners.reverse.foreach { listener =>
var listenerOption: Option[T] = None
while ({listenerOption = getNextListenerOrDeregisterThread(); listenerOption.nonEmpty}) {
val listener = listenerOption.get
try {
callback(listener)
} catch {
@@ -239,10 +239,20 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
/** Contains the throwable thrown while writing the parent iterator to the Python process. */
def exception: Option[Throwable] = Option(_exception)

/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
/**
* Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
* due to cleanup.
*/
def shutdownOnTaskCompletion(): Unit = {
assert(context.isCompleted)
this.interrupt()
// Task completion listeners that run after this method returns may invalidate
// `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
// reader, a task completion listener will free the underlying off-heap buffers. If the writer
// thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
// bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
// thread to exit before returning.
this.join()
}

/**
121 changes: 121 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -18,6 +18,9 @@
package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
@@ -334,6 +337,124 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(e.getMessage.contains("exception in task"))
}

test("listener registers another listener (reentrancy)") {
val context = TaskContext.empty()
var invocations = 0
val simpleListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocations += 1
}
}

// Create a listener that registers another listener.
val reentrantListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
context.addTaskCompletionListener(simpleListener)
invocations += 1
}
}
context.addTaskCompletionListener(reentrantListener)

// Ensure the listener can execute without encountering deadlock.
assert(invocations == 0)
context.markTaskCompleted(None)
assert(invocations == 2)
}

test("listener registers another listener using a second thread") {
val context = TaskContext.empty()
val invocations = new AtomicInteger(0)
val simpleListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocations.getAndIncrement()
}
}

// Create a listener that registers another listener using a second thread.
val multithreadedListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
val thread = new Thread(new Runnable {
override def run(): Unit = {
context.addTaskCompletionListener(simpleListener)
}
})
thread.start()
invocations.getAndIncrement()
thread.join()
}
}
context.addTaskCompletionListener(multithreadedListener)

// Ensure the listener can execute without encountering deadlock.
assert(invocations.get() == 0)
context.markTaskCompleted(None)
assert(invocations.get() == 2)
}

test("listeners registered from different threads are called sequentially") {
val context = TaskContext.empty()
val invocations = new AtomicInteger(0)
val numRunningListeners = new AtomicInteger(0)

// Create a listener that will throw if more than one instance is running at the same time.
val registerExclusiveListener = new Runnable {
override def run(): Unit = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
if (numRunningListeners.getAndIncrement() != 0) throw new Exception()
Thread.sleep(100)
if (numRunningListeners.decrementAndGet() != 0) throw new Exception()
invocations.getAndIncrement()
}
})
}
}

// Register it multiple times from different threads before and after the task completes.
assert(invocations.get() == 0)
assert(numRunningListeners.get() == 0)
val thread1 = new Thread(registerExclusiveListener)
val thread2 = new Thread(registerExclusiveListener)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
assert(invocations.get() == 0)
context.markTaskCompleted(None)
assert(invocations.get() == 2)
val thread3 = new Thread(registerExclusiveListener)
val thread4 = new Thread(registerExclusiveListener)
thread3.start()
thread4.start()
thread3.join()
thread4.join()
assert(invocations.get() == 4)
assert(numRunningListeners.get() == 0)
}

test("listeners registered from same thread are called in reverse order") {
val context = TaskContext.empty()
val invocationOrder = ArrayBuffer.empty[String]

// Create listeners that log an id to `invocationOrder` when they are invoked.
def makeLoggingListener(id: String): TaskCompletionListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocationOrder += id
}
}
context.addTaskCompletionListener(makeLoggingListener("A"))
context.addTaskCompletionListener(makeLoggingListener("B"))
context.addTaskCompletionListener(makeLoggingListener("C"))

// Ensure the listeners are called in reverse order of registration, except when they are called
// after the task is complete.
assert(invocationOrder === Seq.empty)
context.markTaskCompleted(None)
assert(invocationOrder === Seq("C", "B", "A"))
context.addTaskCompletionListener(makeLoggingListener("D"))
assert(invocationOrder === Seq("C", "B", "A", "D"))
}

}

private object TaskContextSuite {