[SPARK-33173][CORE][TESTS][FOLLOWUP] Use local[2] and AtomicInteger
### What changes were proposed in this pull request?

Use `local[2]` so that the two tasks can launch at the same time, and change the counters (`numOnTaskXXX`) to `AtomicInteger` to ensure thread safety.
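For illustration, a minimal standalone sketch (not the suite code) of why plain `var` counters are unsafe here: `+= 1` on an `Int` is a non-atomic read-modify-write, so concurrent task callbacks can lose updates, while `AtomicInteger` counts every increment.

```scala
import java.util.concurrent.atomic.AtomicInteger

object CounterRaceSketch {
  def main(args: Array[String]): Unit = {
    var unsafe = 0                  // plain var: += is read-modify-write, not atomic
    val safe = new AtomicInteger(0) // atomic counter

    val threads = (1 to 4).map { _ =>
      new Thread(() => (1 to 100000).foreach { _ =>
        unsafe += 1              // concurrent increments can interleave and be lost
        safe.incrementAndGet()   // never loses an increment
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())

    println(s"unsafe = $unsafe")        // often < 400000
    println(s"safe   = ${safe.get()}")  // always 400000
  }
}
```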

### Why are the changes needed?

The test is still flaky after the fix in #30072. See: https://github.com/apache/spark/pull/30728/checks?check_run_id=1557987642

It is also easy to reproduce by running the test multiple times (e.g., 100) locally.

The test sets up a stage with 2 tasks to run on an executor with 1 core, so the 2 tasks have to be launched one by one: task-2 is launched only after task-1 fails. However, since failed tasks are not retried in local mode (`MAX_LOCAL_TASK_FAILURES = 1`), the stage aborts right away after task-1 fails and cancels the still-running task-2 at the same time. There is a chance that task-2 gets canceled before it calls `PluginContainer.onTaskStart`, which leads to the test failure.
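To make the race concrete, here is a hedged sketch of the shape of the test (the real body lives in `PluginContainerSuite`, in the diff below; the failing action here is illustrative). With `local[1]`, task-2 is still queued when task-1's failure aborts the stage, so its cancellation can beat `PluginContainer.onTaskStart`; with `local[2]`, both tasks launch concurrently before the abort.

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Sketch only: mirrors the structure of the SPARK-33088 test.
val conf = new SparkConf()
  .setAppName("spark-33088-sketch")
  .setMaster("local[2]") // 2 cores: both tasks start before the stage aborts

val sc = new SparkContext(conf)
try {
  // Two tasks that both fail. With local[1] the second task could be
  // canceled before its PluginContainer.onTaskStart call ever happened.
  sc.parallelize(1 to 2, 2).foreach { _ => throw new RuntimeException("failed") }
} catch {
  case _: SparkException => // expected: the stage aborts after the failures
} finally {
  sc.stop()
}
```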

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Tested manually after the fix; the test is no longer flaky.

Closes #30823 from Ngone51/debug-flaky-spark-33088.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
Ngone51 authored and dongjoon-hyun committed Dec 17, 2020
1 parent 8c81cf7 commit 15616f4
Showing 1 changed file with 14 additions and 13 deletions.
core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala

@@ -20,6 +20,7 @@ package org.apache.spark.internal.plugin
 import java.io.File
 import java.nio.charset.StandardCharsets
 import java.util.{Map => JMap}
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
 import scala.concurrent.duration._
@@ -138,15 +139,15 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     sc = new SparkContext(conf)
     sc.parallelize(1 to 10, 2).count()
 
-    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
-    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
-    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart.get() == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded.get() == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed.get() == 0)
   }
 
   test("SPARK-33088: executor failed tasks trigger plugin calls") {
     val conf = new SparkConf()
       .setAppName(getClass().getName())
-      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(SparkLauncher.SPARK_MASTER, "local[2]")
       .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
 
     sc = new SparkContext(conf)
@@ -157,9 +158,9 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     }
 
     eventually(timeout(10.seconds), interval(100.millis)) {
-      assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
-      assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
-      assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
+      assert(TestSparkPlugin.executorPlugin.numOnTaskStart.get() == 2)
+      assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded.get() == 0)
+      assert(TestSparkPlugin.executorPlugin.numOnTaskFailed.get() == 2)
     }
   }
 
@@ -343,9 +344,9 @@ private class TestDriverPlugin extends DriverPlugin {
 
 private class TestExecutorPlugin extends ExecutorPlugin {
 
-  var numOnTaskStart: Int = 0
-  var numOnTaskSucceeded: Int = 0
-  var numOnTaskFailed: Int = 0
+  val numOnTaskStart = new AtomicInteger(0)
+  val numOnTaskSucceeded = new AtomicInteger(0)
+  val numOnTaskFailed = new AtomicInteger(0)
 
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
@@ -355,15 +356,15 @@ private class TestExecutorPlugin extends ExecutorPlugin {
   }
 
   override def onTaskStart(): Unit = {
-    numOnTaskStart += 1
+    numOnTaskStart.incrementAndGet()
  }
 
   override def onTaskSucceeded(): Unit = {
-    numOnTaskSucceeded += 1
+    numOnTaskSucceeded.incrementAndGet()
  }
 
   override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
-    numOnTaskFailed += 1
+    numOnTaskFailed.incrementAndGet()
  }
 }
