[SPARK-33089][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events

This proposes a new set of APIs for `ExecutorPlugin`, providing callbacks invoked at the start and end of each task of a job. The shape of the API is not set in stone; it is kept as minimal as possible for now.

The changes are described in detail in [SPARK-33088](https://issues.apache.org/jira/browse/SPARK-33088), but they mostly boil down to:

1. This feature was considered when the ExecutorPlugin API was initially introduced in apache#21923, but never implemented.
2. The use-case which **requires** this feature is to propagate tracing information from the driver to the executors, so that calls from the same job can all be traced. A minimal driver-side sketch follows this list.
  a. Tracing frameworks are usually set up in thread locals, so the setup has to happen in the same thread that runs the task.
  b. Executors can run tasks from multiple jobs, so it is not sufficient to set tracing information at executor startup time -- it needs to happen every time a task starts or ends.
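
As an illustration of the tracing use-case above, here is a minimal driver-side sketch (spark-shell style). The property key `example.trace.id` and the app name are made up for this example; local properties set on the driver are shipped with the tasks submitted from that thread and can be read on the executor via `TaskContext.getLocalProperty`:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical property key used only for this sketch.
val traceIdKey = "example.trace.id"

val sc = new SparkContext(new SparkConf().setAppName("tracing-example").setMaster("local[2]"))

// Local properties set on the driver thread travel with the tasks submitted from
// that thread; an ExecutorPlugin can read them in its task callbacks.
sc.setLocalProperty(traceIdKey, java.util.UUID.randomUUID().toString)
sc.parallelize(1 to 100, 2).count()
```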

Does this PR introduce any user-facing change? No. It only introduces new APIs for plugin developers to use.

How was this patch tested? Unit tests in `PluginContainerSuite`.

Closes apache#29977 from fsamuel-bs/SPARK-33088.

Authored-by: Samuel Souza <ssouza@palantir.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
fsamuel-bs authored and rshkv committed Feb 26, 2021
1 parent 7423b57 commit 7739e49
Showing 6 changed files with 165 additions and 15 deletions.
42 changes: 42 additions & 0 deletions core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java
@@ -19,6 +19,7 @@

import java.util.Map;

import org.apache.spark.TaskFailedReason;
import org.apache.spark.annotation.DeveloperApi;

/**
@@ -54,4 +55,45 @@ default void init(PluginContext ctx, Map<String, String> extraConf) {}
*/
default void shutdown() {}

/**
* Perform any action before the task is run.
* <p>
* This method is invoked in the same thread in which the task will be executed.
* Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
* <p>
* Plugin authors should avoid expensive operations here, as this method is called
* on every task; doing something expensive can significantly slow down a job.
* Calling a remote service from this method, for example, is not recommended.
* <p>
* Exceptions thrown from this method do not propagate: they are caught, logged,
* and suppressed, so an exception thrown here will not cause the task or job to fail.
*
* @since 3.1.0
*/
default void onTaskStart() {}

/**
* Perform an action after a task completes successfully (without exceptions).
* <p>
* Because exceptions thrown from {@link #onTaskStart() onTaskStart} are suppressed, this
* method will still be invoked even if the corresponding {@link #onTaskStart} call for this
* task failed.
* <p>
* The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
*
* @since 3.1.0
*/
default void onTaskSucceeded() {}

/**
* Perform an action after a task fails (for example, by throwing an exception).
* <p>
* The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
*
* @param failureReason the reason the task failed, such as the exception thrown from it.
*
* @since 3.1.0
*/
default void onTaskFailed(TaskFailedReason failureReason) {}
}
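
As a rough sketch of how a plugin might use these callbacks for the tracing use-case described in the commit message: the `TraceContext`, `TracingExecutorPlugin`, and `TracingSparkPlugin` names and the `example.trace.id` property key are illustrative only, not part of this change.

```scala
import org.apache.spark.{TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}

// Illustrative thread-local holder for the current task's trace id.
object TraceContext {
  val currentTraceId = new ThreadLocal[String]
}

class TracingExecutorPlugin extends ExecutorPlugin {

  // Invoked in the task's thread, so thread-local setup is safe here.
  override def onTaskStart(): Unit = {
    val traceId = TaskContext.get().getLocalProperty("example.trace.id")
    TraceContext.currentTraceId.set(traceId)
  }

  // Clear the thread local on both success and failure so the pooled task
  // thread does not leak state into the next task.
  override def onTaskSucceeded(): Unit = {
    TraceContext.currentTraceId.remove()
  }

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
    TraceContext.currentTraceId.remove()
  }
}

// Plugins are registered as a SparkPlugin, typically via the spark.plugins configuration.
class TracingSparkPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null
  override def executorPlugin(): ExecutorPlugin = new TracingExecutorPlugin()
}
```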
30 changes: 19 additions & 11 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -225,7 +225,7 @@ private[spark] class Executor(
private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
val tr = new TaskRunner(context, taskDescription, plugins)
runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
}
@@ -301,7 +301,8 @@

class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
private val taskDescription: TaskDescription,
private val plugins: Option[PluginContainer])
extends Runnable {

val taskId = taskDescription.taskId
@@ -443,7 +444,8 @@
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem,
resources = taskDescription.resources)
resources = taskDescription.resources,
plugins = plugins)
threwException = false
res
} {
@@ -579,6 +581,7 @@

executorSource.SUCCEEDED_TASKS.inc(1L)
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskSucceeded())
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: TaskKilledException =>
@@ -588,8 +591,9 @@
// Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
// without requiring a copy.
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums, metricPeaks))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
@@ -598,8 +602,9 @@

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums, metricPeaks))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -613,11 +618,13 @@
s"other exception: $t")
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskCommitDeniedReason
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case t: Throwable if env.isStopped =>
@@ -640,21 +647,22 @@
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))

val serializedTaskEndReason = {
val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
.withMetricPeaks(metricPeaks)
ser.serialize(ef)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
.withMetricPeaks(metricPeaks)
ser.serialize(ef)
(ef, ser.serialize(ef))
}
}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
plugins.foreach(_.onTaskFailed(taskFailureReason))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
}
core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
import scala.collection.JavaConverters._
import scala.util.{Either, Left, Right}

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
import org.apache.spark.api.plugin._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {

def shutdown(): Unit
def registerMetrics(appId: String): Unit
def onTaskStart(): Unit
def onTaskSucceeded(): Unit
def onTaskFailed(failureReason: TaskFailedReason): Unit

}

@@ -85,6 +88,17 @@ private class DriverPluginContainer(
}
}

override def onTaskStart(): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}

override def onTaskSucceeded(): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}
}

private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
}
}
}

override def onTaskStart(): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskStart()
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
}
}
}

override def onTaskSucceeded(): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskSucceeded()
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
}
}
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskFailed(failureReason)
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
}
}
}
}

object PluginContainer {
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,6 +23,7 @@ import java.util.Properties
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
taskAttemptId: Long,
attemptNumber: Int,
metricsSystem: MetricsSystem,
resources: Map[String, ResourceInformation]): T = {
resources: Map[String, ResourceInformation],
plugins: Option[PluginContainer]): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
// TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
// the stage is barrier.
@@ -123,6 +125,8 @@
Option(taskAttemptId),
Option(attemptNumber)).setCurrentContext()

plugins.foreach(_.onTaskStart())

try {
runTask(context)
} catch {
core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -129,6 +129,40 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
assert(TestSparkPlugin.driverPlugin != null)
}

test("SPARK-33088: executor tasks trigger plugin calls") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local[1]")
.set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))

sc = new SparkContext(conf)
sc.parallelize(1 to 10, 2).count()

assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
}

test("SPARK-33088: executor failed tasks trigger plugin calls") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local[1]")
.set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))

sc = new SparkContext(conf)
try {
sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
} catch {
case t: Throwable => // ignore exception
}

eventually(timeout(10.seconds), interval(100.millis)) {
assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
}
}

test("plugin initialization in non-local mode") {
val path = Utils.createTempDir()

@@ -309,13 +343,28 @@ private class TestDriverPlugin extends DriverPlugin {

private class TestExecutorPlugin extends ExecutorPlugin {

var numOnTaskStart: Int = 0
var numOnTaskSucceeded: Int = 0
var numOnTaskFailed: Int = 0

override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
override def getValue(): Int = 84
})
TestSparkPlugin.executorContext = ctx
}

override def onTaskStart(): Unit = {
numOnTaskStart += 1
}

override def onTaskSucceeded(): Unit = {
numOnTaskSucceeded += 1
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
numOnTaskFailed += 1
}
}

private object TestSparkPlugin {
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, null)
task.run(0, 0, null, null, Option.empty)
}
assert(TaskContextSuite.completed)
}
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, null)
task.run(0, 0, null, null, Option.empty)
}
assert(TaskContextSuite.lastError.getMessage == "damn error")
}
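
For completeness, a sketch of how an application could enable such a plugin. `spark.plugins` is the standard plugin registration setting; `com.example.TracingSparkPlugin` is a hypothetical class name standing in for a real implementation.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Register the plugin; the executor-side callbacks added in this change then
// fire around every task run by the executors.
val conf = new SparkConf()
  .setAppName("tracing-example")
  .set("spark.plugins", "com.example.TracingSparkPlugin")
val sc = new SparkContext(conf)
```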