From 10254a94fb8f617d7392da813e4cc8936bafdfaf Mon Sep 17 00:00:00 2001 From: Samuel Souza Date: Thu, 8 Oct 2020 13:34:07 +0100 Subject: [PATCH 1/3] patch --- .../spark/api/plugin/ExecutorPlugin.java | 35 ++++++++++++++ .../org/apache/spark/executor/Executor.scala | 8 ++-- .../internal/plugin/PluginContainer.scala | 47 +++++++++++++++++++ .../org/apache/spark/scheduler/Task.scala | 11 ++++- .../plugin/PluginContainerSuite.scala | 47 +++++++++++++++++++ .../spark/scheduler/TaskContextSuite.scala | 4 +- 6 files changed, 145 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java index 4961308035163..a7fe71920e460 100644 --- a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java +++ b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java @@ -54,4 +54,39 @@ default void init(PluginContext ctx, Map extraConf) {} */ default void shutdown() {} + /** + * Perform any action before the task is run. + *
<p>
+ * This method is invoked from the same thread in which the task will be executed. + * Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}. + *
<p>
+ * Plugin authors should avoid expensive operations here, as this method will be called + * on every task, and doing something expensive can significantly slow down a job. + * Calling a remote service, for example, is not recommended. + *
<p>
+ * Exceptions thrown from this method do not propagate - they're caught, + * logged, and suppressed. Therefore exceptions when executing this method won't + * make the job fail. + */ + default void onTaskStart() {} + + /** + * Perform an action after a task completes without exceptions. + *
<p>
+ * Because exceptions from {@link #onTaskStart() onTaskStart} are suppressed, this method + * will still be invoked even if the corresponding {@link #onTaskStart} call for this + * task failed. + *
<p>
+ * Same warnings of {@link #onTaskStart() onTaskStart} apply here. + */ + default void onTaskSucceeded() {} + + /** + * Perform an action after a task completes with exceptions. + *
<p>
+ * Same warnings of {@link #onTaskStart() onTaskStart} apply here. + * + * @param failureReason the exception thrown from the failed task. + */ + default void onTaskFailed(Throwable failureReason) {} } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 27addd8fc12e2..b19642f74a94a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -253,7 +253,7 @@ private[spark] class Executor( } def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { - val tr = new TaskRunner(context, taskDescription) + val tr = new TaskRunner(context, taskDescription, plugins) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) if (decommissioned) { @@ -332,7 +332,8 @@ private[spark] class Executor( class TaskRunner( execBackend: ExecutorBackend, - private val taskDescription: TaskDescription) + private val taskDescription: TaskDescription, + private val plugins: Option[PluginContainer]) extends Runnable { val taskId = taskDescription.taskId @@ -479,7 +480,8 @@ private[spark] class Executor( taskAttemptId = taskId, attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem, - resources = taskDescription.resources) + resources = taskDescription.resources, + plugins = plugins) threwException = false res } { diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala index 4eda4767094ad..52bd5e7aac205 100644 --- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala @@ -31,6 +31,9 @@ sealed abstract class PluginContainer { def shutdown(): Unit def registerMetrics(appId: String): Unit + def onTaskStart(): Unit + def onTaskSucceeded(): Unit + def onTaskFailed(failureReason: Throwable): Unit } @@ -85,6 +88,17 @@ private class DriverPluginContainer( } } + override def onTaskStart(): Unit = { + throw new IllegalStateException("Should not be called for the driver container.") + } + + override def onTaskSucceeded(): Unit = { + throw new IllegalStateException("Should not be called for the driver container.") + } + + override def onTaskFailed(throwable: Throwable): Unit = { + throw new IllegalStateException("Should not be called for the driver container.") + } } private class ExecutorPluginContainer( @@ -134,6 +148,39 @@ private class ExecutorPluginContainer( } } } + + override def onTaskStart(): Unit = { + executorPlugins.foreach { case (name, plugin) => + try { + plugin.onTaskStart() + } catch { + case t: Throwable => + logInfo(s"Exception while calling onTaskStart on plugin $name.", t) + } + } + } + + override def onTaskSucceeded(): Unit = { + executorPlugins.foreach { case (name, plugin) => + try { + plugin.onTaskSucceeded() + } catch { + case t: Throwable => + logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t) + } + } + } + + override def onTaskFailed(failureReason: Throwable): Unit = { + executorPlugins.foreach { case (name, plugin) => + try { + plugin.onTaskFailed(failureReason) + } catch { + case t: Throwable => + logInfo(s"Exception while calling onTaskFailed on plugin $name.", t) + } + } + } } object PluginContainer { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala 
index ebc1c05435fee..c1320cfe14337 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,6 +23,7 @@ import java.util.Properties import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT +import org.apache.spark.internal.plugin.PluginContainer import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rdd.InputFileBlockHolder @@ -82,7 +83,8 @@ private[spark] abstract class Task[T]( taskAttemptId: Long, attemptNumber: Int, metricsSystem: MetricsSystem, - resources: Map[String, ResourceInformation]): T = { + resources: Map[String, ResourceInformation], + plugins: Option[PluginContainer]): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether // the stage is barrier. @@ -123,8 +125,12 @@ private[spark] abstract class Task[T]( Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + plugins.foreach(_.onTaskStart()) + try { - runTask(context) + val taskResult = runTask(context) + plugins.foreach(_.onTaskSucceeded()) + taskResult } catch { case e: Throwable => // Catch all errors; run task failure callbacks, and rethrow the exception. @@ -135,6 +141,7 @@ private[spark] abstract class Task[T]( e.addSuppressed(t) } context.markTaskCompleted(Some(e)) + plugins.foreach(_.onTaskFailed(e)) throw e } finally { try { diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index 7888796dd55e6..d23e79131cb2d 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo assert(TestSparkPlugin.driverPlugin != null) } + test("SPARK-33088: executor tasks trigger plugin calls") { + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local[1]") + .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) + + sc = new SparkContext(conf) + sc.parallelize(1 to 10, 2).count() + + assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2) + assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2) + assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0) + } + + test("SPARK-33088: executor failed tasks trigger plugin calls") { + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local[1]") + .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) + + sc = new SparkContext(conf) + try { + sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException) + } catch { + case t: Throwable => // ignore exception + } + + assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2) + assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0) + assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2) + } + test("plugin initialization in non-local mode") { val path = Utils.createTempDir() @@ -309,6 +341,10 @@ private class TestDriverPlugin extends DriverPlugin { private class TestExecutorPlugin extends ExecutorPlugin { + var numOnTaskStart: Int = 0 + var numOnTaskSucceeded: Int = 0 + var numOnTaskFailed: Int = 0 + override def init(ctx: PluginContext, 
extraConf: JMap[String, String]): Unit = { ctx.metricRegistry().register("executorMetric", new Gauge[Int] { override def getValue(): Int = 84 @@ -316,6 +352,17 @@ private class TestExecutorPlugin extends ExecutorPlugin { TestSparkPlugin.executorContext = ctx } + override def onTaskStart(): Unit = { + numOnTaskStart += 1 + } + + override def onTaskSucceeded(): Unit = { + numOnTaskSucceeded += 1 + } + + override def onTaskFailed(failureReason: Throwable): Unit = { + numOnTaskFailed += 1 + } } private object TestSparkPlugin { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 394a2a9fbf7cb..8a7ff9eb6dcd3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { - task.run(0, 0, null, null) + task.run(0, 0, null, null, Option.empty) } assert(TaskContextSuite.completed) } @@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { - task.run(0, 0, null, null) + task.run(0, 0, null, null, Option.empty) } assert(TaskContextSuite.lastError.getMessage == "damn error") } From 0c0700603eec7be836c60d518940289642fd5593 Mon Sep 17 00:00:00 2001 From: Samuel Souza Date: Wed, 14 Oct 2020 16:20:54 +0100 Subject: [PATCH 2/3] add @Since --- .../java/org/apache/spark/api/plugin/ExecutorPlugin.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java index a7fe71920e460..65ba97be50676 100644 --- a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java +++ b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java @@ -67,6 +67,8 @@ default void shutdown() {} * Exceptions thrown from this method do not propagate - they're caught, * logged, and suppressed. Therefore exceptions when executing this method won't * make the job fail. + * + * @since 3.1.0 */ default void onTaskStart() {} @@ -78,6 +80,8 @@ default void onTaskStart() {} * task failed. *
<p>
* Same warnings of {@link #onTaskStart() onTaskStart} apply here. + * + * @since 3.1.0 */ default void onTaskSucceeded() {} @@ -87,6 +91,8 @@ default void onTaskSucceeded() {} * Same warnings of {@link #onTaskStart() onTaskStart} apply here. * * @param failureReason the exception thrown from the failed task. + * + * @since 3.1.0 */ default void onTaskFailed(Throwable failureReason) {} } From 8a5e43671b3ffc16a6d3630886027cde97380558 Mon Sep 17 00:00:00 2001 From: Samuel Souza Date: Thu, 15 Oct 2020 11:29:15 +0100 Subject: [PATCH 3/3] task -> executor --- .../spark/api/plugin/ExecutorPlugin.java | 3 ++- .../org/apache/spark/executor/Executor.scala | 24 +++++++++++-------- .../internal/plugin/PluginContainer.scala | 8 +++---- .../org/apache/spark/scheduler/Task.scala | 5 +--- .../plugin/PluginContainerSuite.scala | 2 +- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java index 65ba97be50676..481bf985f1c6c 100644 --- a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java +++ b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java @@ -19,6 +19,7 @@ import java.util.Map; +import org.apache.spark.TaskFailedReason; import org.apache.spark.annotation.DeveloperApi; /** @@ -94,5 +95,5 @@ default void onTaskSucceeded() {} * * @since 3.1.0 */ - default void onTaskFailed(Throwable failureReason) {} + default void onTaskFailed(TaskFailedReason failureReason) {} } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b19642f74a94a..6653650615192 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -616,6 +616,7 @@ private[spark] class Executor( executorSource.SUCCEEDED_TASKS.inc(1L) setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskSucceeded()) execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { case t: TaskKilledException => @@ -625,9 +626,9 @@ private[spark] class Executor( // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq // without requiring a copy. 
val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) - val serializedTK = ser.serialize( - TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)) - execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) + val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq) + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => @@ -636,9 +637,9 @@ private[spark] class Executor( val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) - val serializedTK = ser.serialize( - TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)) - execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) + val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq) + plugins.foreach(_.onTaskFailed(reason)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -652,11 +653,13 @@ private[spark] class Executor( s"other exception: $t") } setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(reason)) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() + plugins.foreach(_.onTaskFailed(reason)) execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) case t: Throwable if env.isStopped => @@ -679,21 +682,22 @@ private[spark] class Executor( val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs) val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId)) - val serializedTaskEndReason = { + val (taskFailureReason, serializedTaskFailureReason) = { try { val ef = new ExceptionFailure(t, accUpdates).withAccums(accums) .withMetricPeaks(metricPeaks.toSeq) - ser.serialize(ef) + (ef, ser.serialize(ef)) } catch { case _: NotSerializableException => // t is not serializable so just send the stacktrace val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums) .withMetricPeaks(metricPeaks.toSeq) - ser.serialize(ef) + (ef, ser.serialize(ef)) } } setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) + plugins.foreach(_.onTaskFailed(taskFailureReason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason) } else { logInfo("Not reporting error to driver during JVM shutdown.") } diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala index 52bd5e7aac205..f78ec250f7173 100644 --- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala @@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin import scala.collection.JavaConverters._ import scala.util.{Either, Left, Right} -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason} import org.apache.spark.api.plugin._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ 
-33,7 +33,7 @@ sealed abstract class PluginContainer { def registerMetrics(appId: String): Unit def onTaskStart(): Unit def onTaskSucceeded(): Unit - def onTaskFailed(failureReason: Throwable): Unit + def onTaskFailed(failureReason: TaskFailedReason): Unit } @@ -96,7 +96,7 @@ private class DriverPluginContainer( throw new IllegalStateException("Should not be called for the driver container.") } - override def onTaskFailed(throwable: Throwable): Unit = { + override def onTaskFailed(failureReason: TaskFailedReason): Unit = { throw new IllegalStateException("Should not be called for the driver container.") } } @@ -171,7 +171,7 @@ private class ExecutorPluginContainer( } } - override def onTaskFailed(failureReason: Throwable): Unit = { + override def onTaskFailed(failureReason: TaskFailedReason): Unit = { executorPlugins.foreach { case (name, plugin) => try { plugin.onTaskFailed(failureReason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c1320cfe14337..81f984bb2b511 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -128,9 +128,7 @@ private[spark] abstract class Task[T]( plugins.foreach(_.onTaskStart()) try { - val taskResult = runTask(context) - plugins.foreach(_.onTaskSucceeded()) - taskResult + runTask(context) } catch { case e: Throwable => // Catch all errors; run task failure callbacks, and rethrow the exception. @@ -141,7 +139,6 @@ private[spark] abstract class Task[T]( e.addSuppressed(t) } context.markTaskCompleted(Some(e)) - plugins.foreach(_.onTaskFailed(e)) throw e } finally { try { diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index d23e79131cb2d..e7fbe5b998a88 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -360,7 +360,7 @@ private class TestExecutorPlugin extends ExecutorPlugin { numOnTaskSucceeded += 1 } - override def onTaskFailed(failureReason: Throwable): Unit = { + override def onTaskFailed(failureReason: TaskFailedReason): Unit = { numOnTaskFailed += 1 } }
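
Note (not part of the patches above): a minimal sketch of how the new hooks could be used, written against the API as it stands after patch 3/3. The class name com.example.TaskTimingPlugin and the timing logic are hypothetical; all three hooks run on the thread that executes the task and are invoked for every task, so the bodies are kept cheap, as the ExecutorPlugin javadoc advises.

// Hypothetical example, not part of this change.
import org.apache.spark.TaskContext;
import org.apache.spark.TaskFailedReason;
import org.apache.spark.api.plugin.DriverPlugin;
import org.apache.spark.api.plugin.ExecutorPlugin;
import org.apache.spark.api.plugin.SparkPlugin;

public class TaskTimingPlugin implements SparkPlugin {

  @Override
  public DriverPlugin driverPlugin() {
    return null;  // no driver-side component needed for this sketch
  }

  @Override
  public ExecutorPlugin executorPlugin() {
    return new ExecutorPlugin() {
      // The hooks run on the task's thread, so a ThreadLocal is enough to
      // correlate the start and end of the same task.
      private final ThreadLocal<Long> startNs = new ThreadLocal<>();
      private final ThreadLocal<Long> taskId = new ThreadLocal<>();

      @Override
      public void onTaskStart() {
        // Called for every task; keep it cheap. TaskContext is available here.
        startNs.set(System.nanoTime());
        taskId.set(TaskContext.get().taskAttemptId());
      }

      @Override
      public void onTaskSucceeded() {
        long elapsedMs = (System.nanoTime() - startNs.get()) / 1_000_000L;
        System.out.println("Task " + taskId.get() + " succeeded in " + elapsedMs + " ms");
      }

      @Override
      public void onTaskFailed(TaskFailedReason failureReason) {
        // This can also fire for failures that happen before onTaskStart runs
        // (for example a task deserialization error), in which case taskId is null.
        System.err.println("Task " + taskId.get() + " failed: " + failureReason.toErrorString());
      }
    };
  }
}

Such a plugin would be enabled by adding its class name to spark.plugins; the PluginContainerSuite tests above exercise the same call pattern, counting invocations instead of timing them.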