Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ class DAGScheduler(
eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
}

/**
 * Reports that a group of tasks belonging to one stage was aborted because of a fetch
 * failure. Invoked by the TaskSetManager; the notification is wrapped in a
 * [[TasksAborted]] event and handed to the scheduler's event loop.
 *
 * @param stageId id of the stage the aborted tasks belong to
 * @param tasks the aborted tasks
 */
def tasksAborted(stageId: Int, tasks: Seq[Task[_]]): Unit = {
  val abortedEvent = TasksAborted(stageId, tasks)
  eventProcessLoop.post(abortedEvent)
}

private[scheduler]
def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized {
// Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
Expand Down Expand Up @@ -810,6 +817,16 @@ class DAGScheduler(
submitWaitingStages()
}

/**
 * Handles a [[TasksAborted]] event by removing the partitions of the aborted tasks from
 * the pending-partition set of the stage they belong to.
 *
 * @param stageId id of the stage whose tasks were aborted
 * @param tasks the aborted tasks
 */
private[scheduler] def handleTasksAborted(
    stageId: Int,
    tasks: Seq[Task[_]]): Unit = {
  // If the stage is not registered under this id, there is nothing to update.
  stageIdToStage.get(stageId).foreach { abortedStage =>
    // Drop all aborted partitions from the pending set in one bulk removal.
    abortedStage.pendingPartitions --= tasks.map(_.partitionId)
  }
}

private[scheduler] def cleanUpAfterSchedulerStop() {
for (job <- activeJobs) {
val error =
Expand Down Expand Up @@ -941,14 +958,22 @@ class DAGScheduler(
}
}

/** Called when stage's parents are available and we can now do its task. */
/**
* Called when stage's parents are available and we can now run its task.
* This only submits the partitions which are missing and have not been
* submitted to the lower-level scheduler for execution.
*/
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingPartitions.clear()

// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
val missingPartitions = stage.findMissingPartitions()
val partitionsToCompute =
missingPartitions.filter(id => !stage.pendingPartitions.contains(id))
stage.pendingPartitions ++= partitionsToCompute

if (partitionsToCompute.isEmpty) {
return
}

// Use the scheduling pool, job group, description, etc. from an ActiveJob associated
// with this Stage
Expand All @@ -971,7 +996,6 @@ class DAGScheduler(
case s: ShuffleMapStage =>
partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage =>
val job = s.activeJob.get
partitionsToCompute.map { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
Expand Down Expand Up @@ -1051,7 +1075,6 @@ class DAGScheduler(

if (tasks.size > 0) {
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingPartitions ++= tasks.map(_.partitionId)
logDebug("New pending partitions: " + stage.pendingPartitions)
taskScheduler.submitTasks(new TaskSet(
tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
Expand Down Expand Up @@ -1153,9 +1176,9 @@ class DAGScheduler(
}

val stage = stageIdToStage(task.stageId)
stage.pendingPartitions -= task.partitionId
event.reason match {
case Success =>
stage.pendingPartitions -= task.partitionId
task match {
case rt: ResultTask[_, _] =>
// Cast to ResultStage here because it's part of the ResultTask
Expand All @@ -1173,6 +1196,12 @@ class DAGScheduler(
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(
SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
} else if (resultStage.pendingPartitions.isEmpty) {
logInfo("Resubmitting " + resultStage + " (" + resultStage.name +
") because some of its tasks had failed: " +
resultStage.findMissingPartitions().mkString(", "))
markStageAsFinished(resultStage)
submitStage(resultStage)
}

// taskSucceeded runs some user code that might throw an exception. Make sure
Expand Down Expand Up @@ -1250,11 +1279,6 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)

if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
} else {
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
// possible the fetch failure has already been handled by the scheduler.
Expand All @@ -1278,7 +1302,6 @@ class DAGScheduler(
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
Expand All @@ -1297,7 +1320,6 @@ class DAGScheduler(
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
}

case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
Expand Down Expand Up @@ -1656,6 +1678,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case TaskSetFailed(taskSet, reason, exception) =>
dagScheduler.handleTaskSetFailed(taskSet, reason, exception)

case TasksAborted(stageId, tasks) =>
dagScheduler.handleTasksAborted(stageId, tasks)

case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,7 @@ private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
extends DAGSchedulerEvent

/**
 * Event posted to the DAGScheduler event loop when some tasks of a stage have been aborted.
 *
 * @param stageId id of the stage the aborted tasks belong to
 * @param tasks the tasks that were aborted
 */
private[scheduler]
case class TasksAborted(stageId: Int, tasks: Seq[Task[_]]) extends DAGSchedulerEvent

/** Event asking the DAGScheduler event loop to resubmit the stages that have failed. */
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.scheduler

import scala.collection.mutable
import scala.collection.mutable.HashSet

import org.apache.spark._
Expand Down Expand Up @@ -68,6 +69,10 @@ private[scheduler] abstract class Stage(
/** Set of jobs that this stage belongs to. */
val jobIds = new HashSet[Int]

/**
* Set of partitions that have already been submitted to the lower-level scheduler;
* they should not be resubmitted when the stage is rerun.
*/
val pendingPartitions = new HashSet[Int]

/** The ID to use for the next new attempt for this stage. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ private[spark] class TaskSetManager(
// the zombie state once at least one attempt of each task has completed successfully, or if the
// task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
// state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
// state in order to continue to track and account for the running tasks.
// TODO: We should kill any running task attempts when the task set manager becomes a zombie.
// state in order to continue to track and account for the running tasks. The tasks running in the
// zombie TaskSetManagers are not rerun by the DagScheduler unless they fail.
var isZombie = false

// Set of pending tasks for each executor. These collections are actually
Expand Down Expand Up @@ -652,13 +652,30 @@ private[spark] class TaskSetManager(
reason.asInstanceOf[TaskFailedReason].toErrorString
val failureException: Option[Throwable] = reason match {
case fetchFailed: FetchFailed =>
if (!isZombie) {
// Only on the first occurrence of a fetch failure, collect all tasks that are
// neither running nor successful and notify the DAGScheduler that they were
// aborted, so that they can be rescheduled when the stage is retried.
// Note that this does not include the fetch-failed task itself,
// because that is handled separately by the DAGScheduler.
val abortedTasks = new ArrayBuffer[Task[_]]
for (i <- 0 until numTasks) {
if (i != index && !successful(i) && copiesRunning(i) == 0) {
abortedTasks += taskSet.tasks(i)
}
}
if (!abortedTasks.isEmpty) {
sched.dagScheduler.tasksAborted(abortedTasks(0).stageId, abortedTasks)
}
isZombie = true
}

logWarning(failureReason)
if (!successful(index)) {
successful(index) = true
tasksSuccessful += 1
}
// Not adding to failed executors for FetchFailed.
isZombie = true
None

case ef: ExceptionFailure =>
Expand Down
Loading