[SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events #29977

Closed · wants to merge 3 commits
42 changes: 42 additions & 0 deletions core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java
@@ -19,6 +19,7 @@

import java.util.Map;

import org.apache.spark.TaskFailedReason;
import org.apache.spark.annotation.DeveloperApi;

/**
@@ -54,4 +55,45 @@ default void init(PluginContext ctx, Map<String, String> extraConf) {}
*/
default void shutdown() {}

/**
* Perform any action before the task is run.
* <p>
* This method is invoked from the same thread in which the task will be executed.
* Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
* <p>
* Plugin authors should avoid expensive operations here, as this method will be called
* on every task, and doing something expensive can significantly slow down a job.
* It is not recommended for a user to call a remote service, for example.
* <p>
* Exceptions thrown from this method do not propagate - they're caught,
* logged, and suppressed. Therefore exceptions when executing this method won't
* make the job fail.
*
* @since 3.1.0
*/
default void onTaskStart() {}

/**
* Perform an action after a task completes without exceptions.
* <p>
* Since exceptions from {@link #onTaskStart() onTaskStart} are suppressed, this method
* will still be invoked even if the corresponding {@link #onTaskStart} call for this
* task failed.
* <p>
* The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
*
* @since 3.1.0
*/
default void onTaskSucceeded() {}

/**
* Perform an action after a task completes with exceptions.
* <p>
* The same warnings as for {@link #onTaskStart() onTaskStart} apply here.
*
* @param failureReason the reason the task failed.
*
* @since 3.1.0
*/
default void onTaskFailed(TaskFailedReason failureReason) {}
}
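For context on how a plugin would consume the new callbacks, here is a minimal sketch in Scala (the TaskCountingPlugin name and its counters are illustrative assumptions, not part of this PR). It simply counts task outcomes on each executor and would be registered via the existing spark.plugins configuration using its fully qualified class name.

import java.util.{Map => JMap}
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.TaskFailedReason
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}

// Hypothetical example: count task outcomes on each executor with the new callbacks.
class TaskCountingPlugin extends SparkPlugin {
  // No driver-side behavior is needed for this sketch.
  override def driverPlugin(): DriverPlugin = null

  override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {
    private val started = new AtomicLong()
    private val succeeded = new AtomicLong()
    private val failed = new AtomicLong()

    override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = ()

    // Invoked on the task's thread before the task body runs; keep this cheap.
    override def onTaskStart(): Unit = started.incrementAndGet()

    override def onTaskSucceeded(): Unit = succeeded.incrementAndGet()

    override def onTaskFailed(failureReason: TaskFailedReason): Unit = failed.incrementAndGet()
  }
}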
32 changes: 19 additions & 13 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -253,7 +253,7 @@ private[spark] class Executor(
}

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
val tr = new TaskRunner(context, taskDescription, plugins)
runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
if (decommissioned) {
@@ -332,7 +332,8 @@

class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
private val taskDescription: TaskDescription,
private val plugins: Option[PluginContainer])
Contributor: Maybe we can make the plugins parameter optional or default to some EmptyPluginContainer?

Suggested change:
    private val plugins: Option[PluginContainer])
    private val plugins: Option[PluginContainer] = None)

Contributor: Same for Task#run.

Contributor: @rshkv, what is the reason to make this default to None? This is an internal API and it is only called from here. It is already an Option, so people can check it easily. In some ways it's nice to force it, so you make sure all uses of it have been updated. Are there cases where you know this is used outside Spark?
extends Runnable {

val taskId = taskDescription.taskId
@@ -479,7 +480,8 @@ private[spark] class Executor(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem,
resources = taskDescription.resources)
resources = taskDescription.resources,
plugins = plugins)
threwException = false
res
} {
@@ -614,6 +616,7 @@

executorSource.SUCCEEDED_TASKS.inc(1L)
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskSucceeded())
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: TaskKilledException =>
@@ -623,9 +626,9 @@
// Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
// without requiring a copy.
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val serializedTK = ser.serialize(
TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
@@ -634,9 +637,9 @@

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
val serializedTK = ser.serialize(
TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -650,11 +653,13 @@
s"other exception: $t")
}
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskCommitDeniedReason
setTaskFinishedAndClearInterruptStatus()
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case t: Throwable if env.isStopped =>
@@ -677,21 +682,22 @@
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))

val serializedTaskEndReason = {
val (taskFailureReason, serializedTaskFailureReason) = {
try {
val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
.withMetricPeaks(metricPeaks.toSeq)
ser.serialize(ef)
(ef, ser.serialize(ef))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
.withMetricPeaks(metricPeaks.toSeq)
ser.serialize(ef)
(ef, ser.serialize(ef))
}
}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
plugins.foreach(_.onTaskFailed(taskFailureReason))
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
} else {
logInfo("Not reporting error to driver during JVM shutdown.")
}
core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
import scala.collection.JavaConverters._
import scala.util.{Either, Left, Right}

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
import org.apache.spark.api.plugin._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {

def shutdown(): Unit
def registerMetrics(appId: String): Unit
def onTaskStart(): Unit
def onTaskSucceeded(): Unit
def onTaskFailed(failureReason: TaskFailedReason): Unit

}

@@ -85,6 +88,17 @@ private class DriverPluginContainer(
}
}

override def onTaskStart(): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}

override def onTaskSucceeded(): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
throw new IllegalStateException("Should not be called for the driver container.")
}
}

private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
}
}
}

override def onTaskStart(): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskStart()
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
}
}
}

override def onTaskSucceeded(): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskSucceeded()
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
}
}
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
executorPlugins.foreach { case (name, plugin) =>
try {
plugin.onTaskFailed(failureReason)
} catch {
case t: Throwable =>
logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
}
}
}
}

object PluginContainer {
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,6 +23,7 @@ import java.util.Properties
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
taskAttemptId: Long,
attemptNumber: Int,
metricsSystem: MetricsSystem,
resources: Map[String, ResourceInformation]): T = {
resources: Map[String, ResourceInformation],
plugins: Option[PluginContainer]): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
// TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
// the stage is barrier.
@@ -123,6 +125,8 @@
Option(taskAttemptId),
Option(attemptNumber)).setCurrentContext()

plugins.foreach(_.onTaskStart())
Contributor: What is the expectation in case onTaskStart fails - do we want to invoke succeeded/failed?

Contributor Author: That's what I documented on https://github.com/apache/spark/pull/29977/files#diff-6a99ec9983962323b4e0c1899134b5d6R76-R78 -- the argument that came to mind is that it's easy for a plugin dev to track some state in a thread-local and cleanly decide whether it wants to perform the succeeded/failed action or not.

Happy to change it if we prefer not to put this burden on the plugin owner, though.

Contributor: Maybe I'm misunderstanding, but the documentation states "Exceptions thrown from this method do not propagate", and there is nothing here preventing that. I think perhaps you meant to say the user needs to make sure they don't propagate?

Contributor Author: We catch Throwable in ExecutorPluginContainer#onTaskStart and its siblings (see https://github.com/apache/spark/pull/29977/files#diff-5e4d939e9bb53b4be2c48d4eb53b885c162c729b9adc874f918f7701a352cdbbR157), so that's what I meant by "not propagate". That is, if a plugin's onTaskStart throws, Spark will log the exception but won't fail the associated Spark task.

Contributor: Perhaps reword it to say exceptions are ignored?

try {
runTask(context)
} catch {
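A sketch of the thread-local approach mentioned in the review thread above (all names are hypothetical and the setup/teardown helpers are placeholders, not code from this PR). Since onTaskStart is invoked from Task.run and the end-of-task callbacks from TaskRunner.run on the same thread that executes the task, a ThreadLocal flag is enough for the end callbacks to tell whether the plugin's own start-up work completed:

import org.apache.spark.TaskFailedReason
import org.apache.spark.api.plugin.ExecutorPlugin

// Hypothetical sketch of the thread-local pattern discussed in the review thread.
class GuardedExecutorPlugin extends ExecutorPlugin {
  // True only if this plugin's own per-task setup completed for the current task.
  private val setupDone = new ThreadLocal[Boolean] {
    override def initialValue(): Boolean = false
  }

  override def onTaskStart(): Unit = {
    setupDone.set(false)
    perTaskSetup()      // may throw; Spark logs and suppresses the exception
    setupDone.set(true)
  }

  override def onTaskSucceeded(): Unit = {
    if (setupDone.get()) perTaskTeardown()
  }

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
    if (setupDone.get()) perTaskTeardown()
  }

  // Placeholders standing in for whatever per-task work a real plugin performs.
  private def perTaskSetup(): Unit = ()
  private def perTaskTeardown(): Unit = ()
}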
core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
assert(TestSparkPlugin.driverPlugin != null)
}

test("SPARK-33088: executor tasks trigger plugin calls") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local[1]")
.set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))

sc = new SparkContext(conf)
sc.parallelize(1 to 10, 2).count()

assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
}

test("SPARK-33088: executor failed tasks trigger plugin calls") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local[1]")
.set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))

sc = new SparkContext(conf)
try {
sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
} catch {
case t: Throwable => // ignore exception
}

assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
Member @dongjoon-hyun (Oct 17, 2020): Hi, folks. It turns out that this is a flaky test. I filed a JIRA issue and made a PR.

}

test("plugin initialization in non-local mode") {
val path = Utils.createTempDir()

@@ -309,13 +341,28 @@ private class TestDriverPlugin extends DriverPlugin {

private class TestExecutorPlugin extends ExecutorPlugin {

var numOnTaskStart: Int = 0
var numOnTaskSucceeded: Int = 0
var numOnTaskFailed: Int = 0

override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
override def getValue(): Int = 84
})
TestSparkPlugin.executorContext = ctx
}

override def onTaskStart(): Unit = {
numOnTaskStart += 1
}

override def onTaskSucceeded(): Unit = {
numOnTaskSucceeded += 1
}

override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
numOnTaskFailed += 1
}
}

private object TestSparkPlugin {
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, null)
task.run(0, 0, null, null, Option.empty)
}
assert(TaskContextSuite.completed)
}
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
closureSerializer.serialize(TaskMetrics.registered).array())
intercept[RuntimeException] {
task.run(0, 0, null, null)
task.run(0, 0, null, null, Option.empty)
}
assert(TaskContextSuite.lastError.getMessage == "damn error")
}