diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4dfd532e9362..9e6cc85cbae3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -252,6 +252,13 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
   }
 
+  /**
+   * Called by the TaskSetManager when a set of tasks is aborted due to a fetch failure.
+   */
+  def tasksAborted(stageId: Int, tasks: Seq[Task[_]]): Unit = {
+    eventProcessLoop.post(TasksAborted(stageId, tasks))
+  }
+
   private[scheduler]
   def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized {
     // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
@@ -810,6 +817,16 @@ class DAGScheduler(
     submitWaitingStages()
   }
 
+  private[scheduler] def handleTasksAborted(
+      stageId: Int,
+      tasks: Seq[Task[_]]): Unit = {
+    for (stage <- stageIdToStage.get(stageId)) {
+      for (task <- tasks) {
+        stage.pendingPartitions -= task.partitionId
+      }
+    }
+  }
+
   private[scheduler] def cleanUpAfterSchedulerStop() {
     for (job <- activeJobs) {
       val error =
@@ -941,14 +958,22 @@ class DAGScheduler(
     }
   }
 
-  /** Called when stage's parents are available and we can now do its task. */
+  /**
+   * Called when the stage's parents are available and we can now run its tasks.
+   * This only submits the partitions which are missing and have not already been
+   * submitted to the lower-level scheduler for execution.
+   */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
-    // Get our pending tasks and remember them in our pendingTasks entry
-    stage.pendingPartitions.clear()
 
-    // First figure out the indexes of partition ids to compute.
-    val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
+    val missingPartitions = stage.findMissingPartitions()
+    val partitionsToCompute =
+      missingPartitions.filter(id => !stage.pendingPartitions.contains(id))
+    stage.pendingPartitions ++= partitionsToCompute
+
+    if (partitionsToCompute.isEmpty) {
+      return
+    }
 
     // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
     // with this Stage
@@ -971,7 +996,6 @@ class DAGScheduler(
         case s: ShuffleMapStage =>
           partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
         case s: ResultStage =>
-          val job = s.activeJob.get
           partitionsToCompute.map { id =>
             val p = s.partitions(id)
            (id, getPreferredLocs(stage.rdd, p))
@@ -1051,7 +1075,6 @@ class DAGScheduler(
 
     if (tasks.size > 0) {
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      stage.pendingPartitions ++= tasks.map(_.partitionId)
       logDebug("New pending partitions: " + stage.pendingPartitions)
       taskScheduler.submitTasks(new TaskSet(
         tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
@@ -1153,9 +1176,9 @@ class DAGScheduler(
     }
 
     val stage = stageIdToStage(task.stageId)
+    stage.pendingPartitions -= task.partitionId
     event.reason match {
       case Success =>
-        stage.pendingPartitions -= task.partitionId
         task match {
           case rt: ResultTask[_, _] =>
             // Cast to ResultStage here because it's part of the ResultTask
@@ -1173,6 +1196,12 @@ class DAGScheduler(
                     cleanupStateForJobAndIndependentStages(job)
                     listenerBus.post(
                       SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
+                  } else if (resultStage.pendingPartitions.isEmpty) {
+                    logInfo("Resubmitting " + resultStage + " (" + resultStage.name +
+                      ") because some of its tasks had failed: " +
+                      resultStage.findMissingPartitions().mkString(", "))
+                    markStageAsFinished(resultStage)
+                    submitStage(resultStage)
                   }
 
                   // taskSucceeded runs some user code that might throw an exception. Make sure
@@ -1250,11 +1279,6 @@ class DAGScheduler(
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
 
-        if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
-          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
-            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
-            s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
-        } else {
           // It is likely that we receive multiple FetchFailed for a single stage (because we have
           // multiple tasks running concurrently on different executors). In that case, it is
           // possible the fetch failure has already been handled by the scheduler.
@@ -1278,7 +1302,6 @@ class DAGScheduler(
           } else if (failedStages.isEmpty) {
             // Don't schedule an event to resubmit failed stages if failed isn't empty, because
             // in that case the event will already have been scheduled.
-            // TODO: Cancel running tasks in the stage
             logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
               s"$failedStage (${failedStage.name}) due to fetch failure")
             messageScheduler.schedule(new Runnable {
@@ -1297,7 +1320,6 @@ class DAGScheduler(
           if (bmAddress != null) {
             handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
           }
-        }
 
       case commitDenied: TaskCommitDenied =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
@@ -1656,6 +1678,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case TaskSetFailed(taskSet, reason, exception) =>
       dagScheduler.handleTaskSetFailed(taskSet, reason, exception)
 
+    case TasksAborted(stageId, tasks) =>
+      dagScheduler.handleTasksAborted(stageId, tasks)
+
     case ResubmitFailedStages =>
       dagScheduler.resubmitFailedStages()
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 8c761124824a..0e784c0c044e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -83,4 +83,7 @@ private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String,
     exception: Option[Throwable]) extends DAGSchedulerEvent
 
+private[scheduler]
+case class TasksAborted(stageId: Int, tasks: Seq[Task[_]]) extends DAGSchedulerEvent
+
 private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 2f972b064b47..21915da79ddc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import scala.collection.mutable
 import scala.collection.mutable.HashSet
 
 import org.apache.spark._
@@ -68,6 +69,10 @@ private[scheduler] abstract class Stage(
   /** Set of jobs that this stage belongs to. */
   val jobIds = new HashSet[Int]
 
+  /**
+   * Set of partitions that have been submitted to the lower-level scheduler and
+   * should not be resubmitted when the stage is rerun.
+   */
   val pendingPartitions = new HashSet[Int]
 
   /** The ID to use for the next new attempt for this stage. */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index bfa1e86749a4..2a8f3ef0464b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -106,8 +106,8 @@ private[spark] class TaskSetManager(
   // the zombie state once at least one attempt of each task has completed successfully, or if the
   // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
   // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
-  // state in order to continue to track and account for the running tasks.
-  // TODO: We should kill any running task attempts when the task set manager becomes a zombie.
+  // state in order to continue to track and account for the running tasks. Tasks running in
+  // zombie TaskSetManagers are not rerun by the DAGScheduler unless they fail.
   var isZombie = false
 
   // Set of pending tasks for each executor. These collections are actually
@@ -652,13 +652,30 @@ private[spark] class TaskSetManager(
       reason.asInstanceOf[TaskFailedReason].toErrorString
     val failureException: Option[Throwable] = reason match {
       case fetchFailed: FetchFailed =>
+        if (!isZombie) {
+          // Only on the first occurrence of a fetch failure, collect all tasks that are neither
+          // running nor successful and notify the DAGScheduler that they were aborted, so that
+          // they can be rescheduled in the retry of the stage. Note that this does not include
+          // the fetch-failed task itself, because that one is handled separately by the
+          // DAGScheduler.
+          val abortedTasks = new ArrayBuffer[Task[_]]
+          for (i <- 0 until numTasks) {
+            if (i != index && !successful(i) && copiesRunning(i) == 0) {
+              abortedTasks += taskSet.tasks(i)
+            }
+          }
+          if (!abortedTasks.isEmpty) {
+            sched.dagScheduler.tasksAborted(abortedTasks(0).stageId, abortedTasks)
+          }
+          isZombie = true
+        }
+
         logWarning(failureReason)
         if (!successful(index)) {
           successful(index) = true
           tasksSuccessful += 1
         }
         // Not adding to failed executors for FetchFailed.
-        isZombie = true
         None
 
       case ef: ExceptionFailure =>
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 844c780a3fdd..9dfa4284162d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -833,6 +833,219 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assert(results === Map(0 -> 42))
   }
 
+  test("Test no duplicate shuffle map tasks running on fetch failure (SPARK-14649)") {
+    val firstRDD = new MyRDD(sc, 2, Nil)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
+    val firstShuffleId = firstShuffleDep.shuffleId
+    val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    submit(reduceRdd, Array(0))
+
+    // things start out smoothly, stage 0 completes with no issues
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)),
+      (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length))
+    ))
+
+    // Begin event for the reduce tasks.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Fail task 0 in stage 1 due to a fetch failure
+    val failedTask1 = taskSets(1).tasks(0)
+    runEvent(makeCompletionEvent(
+      failedTask1,
+      FetchFailed(makeBlockManagerId("hostA"), firstShuffleId, 0, 0, "ignored"),
+      42, Seq.empty, createFakeTaskInfo()))
+
+    // Abort task 1 in stage 1 due to the fetch failure; task 2 is still running
+    val abortedTask1 = taskSets(1).tasks(1)
+    runEvent(new TasksAborted(1, List(abortedTask1)))
+
+    // Make sure that stage 1 is now marked as failed
+    assert(sparkListener.failedStages.contains(1))
+
+    // Wait for resubmission of the map stage
+    Thread.sleep(1000)
+
+    // so we resubmit stage 0, which completes happily
+    val stage0Resubmit1 = taskSets(2)
+    assert(stage0Resubmit1.stageId == 0)
+    assert(stage0Resubmit1.stageAttemptId === 1)
+    val task1 = stage0Resubmit1.tasks(0)
+    runEvent(makeCompletionEvent(
+      task1,
+      Success,
+      makeMapStatus("hostC", shuffleMapRdd.partitions.length),
+      Seq.empty, createFakeTaskInfo()))
+
+    // we will now have a task set representing
+    // the second attempt for stage 1, but we *also* have 1 task for the first attempt for
+    // stage 1 still going, so we make sure that we don't resubmit the already running tasks.
+    val stage1Resubmit1 = taskSets(3)
+    assert(stage1Resubmit1.stageId == 1)
+    assert(stage1Resubmit1.stageAttemptId === 1)
+    assert(stage1Resubmit1.tasks.length === 2)
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+    // Now fail the task still running from the first attempt and
+    // succeed all the others
+    val succeededTask1 = taskSets(3).tasks(0)
+    runEvent(makeCompletionEvent(
+      succeededTask1,
+      Success,
+      makeMapStatus("hostC", reduceRdd.partitions.length),
+      Seq.empty, createFakeTaskInfo()))
+
+    val failedTask2 = taskSets(1).tasks(2)
+    runEvent(makeCompletionEvent(
+      failedTask2,
+      FetchFailed(makeBlockManagerId("hostB"), firstShuffleId, 0, 0, "ignored"),
+      42, Seq.empty, createFakeTaskInfo()))
+
+    val succeededTask2 = taskSets(3).tasks(1)
+    runEvent(makeCompletionEvent(
+      succeededTask2,
+      Success,
+      makeMapStatus("hostC", reduceRdd.partitions.length),
+      Seq.empty, createFakeTaskInfo()))
+
+    // Sleep for some time so that the completion events are processed by the DAGScheduler,
+    // and make sure that stage 0 is resubmitted.
+    Thread.sleep(1000)
+    val stage0Resubmit2 = taskSets(4)
+    assert(stage0Resubmit2.stageId == 0)
+    assert(stage0Resubmit2.stageAttemptId === 2)
+    assert(stage0Resubmit2.tasks.length === 1)
+    val task2 = stage0Resubmit2.tasks(0)
+    runEvent(makeCompletionEvent(
+      task2,
+      Success,
+      makeMapStatus("hostC", shuffleMapRdd.partitions.length)))
+
+    Thread.sleep(1000)
+    // Make sure we resubmit only the failed tasks in stage 1.
+    val stage1Resubmit2 = taskSets(5)
+    assert(stage1Resubmit2.stageId == 1)
+    assert(stage1Resubmit2.stageAttemptId === 2)
+    assert(stage1Resubmit2.tasks.length === 1)
+  }
+
+  test("Test no duplicate result tasks running on fetch failure (SPARK-14649)") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 3, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1, 2))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
+    // Begin event for the reduce tasks.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    // Fail one and abort one of the reduce tasks due to fetch failure.
+    assert(taskSets(1).tasks.size === 3)
+    val failedTask = taskSets(1).tasks(0)
+    runEvent(makeCompletionEvent(
+      failedTask,
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      42, Seq.empty, createFakeTaskInfo()))
+    val abortedTask = taskSets(1).tasks(1)
+    runEvent(new TasksAborted(1, List(abortedTask)))
+    // Note that taskSets(1).tasks(2) is still in running state, so
+    // it should not be resubmitted in the next retry.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(sparkListener.failedStages.contains(1))
+
+    // Wait for resubmission of the map stage
+    Thread.sleep(1000)
+    // Retry of the map stage finishes happily
+    assert(taskSets(2).tasks.size === 1)
+    complete(taskSets(2), Seq(
+      (Success, makeMapStatus("hostC", reduceRdd.partitions.length))))
+
+    // Newly submitted taskSet for the reduce phase should not contain the
+    // running task.
+    assert(taskSets(3).tasks.size == 2)
+
+    // Finish taskSets(3) successfully
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfo()))
+
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(1), Success, 42,
+      Seq.empty, createFakeTaskInfo()))
+
+    // Fail the task still running in taskSets(1) and make sure it is resubmitted
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(2),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      42, Seq.empty, createFakeTaskInfo()))
+
+    // Make sure that we resubmit the failed reduce task
+    Thread.sleep(1000)
+    assert(taskSets(4).tasks.size == 1)
+  }
+
+  test("Test task rerun in case of failure in zombie taskSet (SPARK-14649)") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 3, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1, 2))
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
+    // Begin event for the reduce tasks.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    // Fail one and abort one of the reduce tasks due to fetch failure.
+    assert(taskSets(1).tasks.size === 3)
+    val failedTask = taskSets(1).tasks(0)
+    runEvent(makeCompletionEvent(
+      failedTask,
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      42, Seq.empty, createFakeTaskInfo()))
+    val abortedTask = taskSets(1).tasks(1)
+    runEvent(new TasksAborted(1, List(abortedTask)))
+    // Note that taskSets(1).tasks(2) is still in running state, so
+    // it should not be resubmitted in the next retry.
+    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    assert(sparkListener.failedStages.contains(1))
+
+    // Wait for resubmission of the map stage
+    Thread.sleep(1000)
+    // Retry of the map stage finishes happily
+    assert(taskSets(2).tasks.size === 1)
+    complete(taskSets(2), Seq(
+      (Success, makeMapStatus("hostC", reduceRdd.partitions.length))))
+
+    // Newly submitted taskSet for the reduce phase should not contain the
+    // running task.
+    assert(taskSets(3).tasks.size == 2)
+
+    // Fail the task still running in taskSets(1) with an ExceptionFailure. Since
+    // taskSets(1) is now in zombie state, the DAGScheduler should rerun the
+    // failed task.
+
+    val exceptionFailure = new ExceptionFailure(
+      new SparkException("fondue?"), Seq.empty)
+
+    runEvent(makeCompletionEvent(taskSets(1).tasks(2), exceptionFailure, "result"))
+
+    // Finish taskSets(3) successfully
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(0), Success, 42,
+      Seq.empty, createFakeTaskInfo()))
+
+    runEvent(makeCompletionEvent(
+      taskSets(3).tasks(1), Success, 42,
+      Seq.empty, createFakeTaskInfo()))
+
+    // Make sure that we resubmit the failed reduce task
+    Thread.sleep(1000)
+    assert(taskSets(4).tasks.size == 1)
+  }
+
   test("trivial shuffle with multiple fetch failures") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
@@ -927,10 +1140,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
   /**
    * This tests the case where a late FetchFailed comes in after the map stage has finished getting
-   * retried and a new reduce stage starts running.
+   * retried and a new reduce stage starts running. We make sure that the late FetchFailed is not
+   * ignored and the map stage is resubmitted.
    */
-  test("extremely late fetch failures don't cause multiple concurrent attempts for " +
-    "the same stage") {
+  test("extremely late fetch failures should cause rerun of the map stage") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
@@ -983,8 +1196,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     // the FetchFailed should have been ignored
     runEvent(ResubmitFailedStages)
 
-    // The FetchFailed from the original reduce stage should be ignored.
-    assert(countSubmittedMapStageAttempts() === 2)
+    // The FetchFailed from the original reduce stage should not be ignored.
+    assert(countSubmittedMapStageAttempts() === 3)
   }
 
   test("task events always posted in speculation / when stage is killed") {
@@ -1193,11 +1406,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
     // now here is where things get tricky : we will now have a task set representing
     // the second attempt for stage 1, but we *also* have some tasks for the first attempt for
-    // stage 1 still going
+    // stage 1 still going, so we make sure that we don't resubmit the already running tasks.
     val stage1Resubmit = taskSets(3)
     assert(stage1Resubmit.stageId == 1)
     assert(stage1Resubmit.stageAttemptId === 1)
-    assert(stage1Resubmit.tasks.length === 3)
+    assert(stage1Resubmit.tasks.length === 1)
 
     // we'll have some tasks finish from the first attempt, and some finish from the second attempt,
     // so that we actually have all stage outputs, though no attempt has completed all its
@@ -1207,7 +1420,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       Success,
       makeMapStatus("hostC", reduceRdd.partitions.length)))
     runEvent(makeCompletionEvent(
-      taskSets(3).tasks(1),
+      taskSets(1).tasks(1),
       Success,
       makeMapStatus("hostC", reduceRdd.partitions.length)))
     // late task finish from the first attempt
@@ -2021,4 +2234,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
   }
 
+  private def makeBeginEvent(
+      task: Task[_],
+      taskInfo: TaskInfo): BeginEvent = {
+    BeginEvent(task, taskInfo)
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 4fe705b201ec..45991736232f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -21,7 +21,8 @@ import org.apache.spark.TaskContext
 
 class FakeTask(
     stageId: Int,
-    prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
+    partitionId: Int = 0,
+    prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) {
   override def runTask(context: TaskContext): Int = 0
   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }
@@ -40,7 +41,7 @@ object FakeTask {
       throw new IllegalArgumentException("Wrong number of task locations")
     }
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
-      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+      new FakeTask(0, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
     }
     new TaskSet(tasks, 0, stageAttemptId, 0, null)
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
index 467796d7c24b..54d4302b5a9e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
@@ -30,7 +30,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext {
   def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl)
     : TaskSetManager = {
     val tasks = Array.tabulate[Task[_]](numTasks) { i =>
-      new FakeTask(i, Nil)
+      new FakeTask(i, 0, Nil)
     }
     new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0)
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 9b7b945bf367..c4c8d16991e1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -20,19 +20,25 @@ package org.apache.spark.scheduler
 import java.util.Random
 
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.{AccumulatorV2, ManualClock}
+import org.apache.spark.storage.BlockManagerId
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
+  val abortedPartitions = new HashSet[Int]
 
   override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
     taskScheduler.startedTasks += taskInfo.index
   }
 
+  override def tasksAborted(stageId: Int, tasks: Seq[Task[_]]): Unit = {
+    for (task <- tasks) {
+      abortedPartitions += task.partitionId
+    }
+  }
   override def taskEnded(
       task: Task[_],
       reason: TaskEndReason,
@@ -393,6 +399,48 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     }
   }
 
+  test("pending tasks should be aborted after first fetch failure") {
+    val rescheduleDelay = 300L
+    val conf = new SparkConf().
+      set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString).
+      // don't wait to jump locality levels in this test
+      set("spark.locality.wait", "0")
+
+    sc = new SparkContext("local", "test", conf)
+    // two executors on same host, one on different.
+    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
+      ("exec1.1", "host1"), ("exec2", "host2"))
+    // affinity to exec1 on host1 - which we will fail.
+    val taskSet = FakeTask.createTaskSet(4)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, 4, clock)
+
+    val offerResult1 = manager.resourceOffer("exec1", "host1", ANY)
+    assert(offerResult1.isDefined, "Expect resource offer to return a task")
+
+    assert(offerResult1.get.index === 0)
+    assert(offerResult1.get.executorId === "exec1")
+
+    val offerResult2 = manager.resourceOffer("exec2", "host2", ANY)
+    assert(offerResult2.isDefined, "Expect resource offer to return a task")
+
+    assert(offerResult2.get.index === 1)
+    assert(offerResult2.get.executorId === "exec2")
+    // At this point, we have 2 tasks running and 2 pending. The first fetch failure should
+    // abort all the pending tasks, but the running tasks should not be aborted.
+    manager.handleFailedTask(offerResult1.get.taskId, TaskState.FINISHED,
+      FetchFailed(BlockManagerId("exec-host2", "host2", 12345), 0, 0, 0, "ignored"))
+    val dagScheduler = sched.dagScheduler.asInstanceOf[FakeDAGScheduler]
+    assert(dagScheduler.abortedPartitions.size === 2)
+
+    dagScheduler.abortedPartitions.clear()
+
+    // A second fetch failure should not notify the DAGScheduler of any aborted tasks.
+    manager.handleFailedTask(offerResult1.get.taskId, TaskState.FINISHED,
+      FetchFailed(BlockManagerId("exec-host2", "host2", 12345), 0, 0, 0, "ignored"))
+    assert(dagScheduler.abortedPartitions.size === 0)
+  }
+
   test("executors should be blacklisted after task failure, in spite of locality preferences") {
     val rescheduleDelay = 300L
     val conf = new SparkConf().