
[SPARK-51272][CORE] Aborting instead of continuing partially completed indeterminate result stage at ResubmitFailedStages #50630

Closed · wants to merge 12 commits
144 changes: 90 additions & 54 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
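For context (an aside, not part of the PR): an indeterminate stage is one whose output can differ between attempts, the classic case being a round-robin `repartition` over input with no deterministic order. The error messages in this diff point at the standard workaround, checkpointing before repartitioning. A minimal user-level sketch of that workaround, with a hypothetical app name, checkpoint directory, and input path:

```scala
import org.apache.spark.sql.SparkSession

object CheckpointBeforeRepartition {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("checkpoint-demo").getOrCreate()
    val sc = spark.sparkContext
    sc.setCheckpointDir("/tmp/checkpoints") // hypothetical checkpoint location

    val lengths = sc.textFile("/data/input.txt").map(_.length) // hypothetical input
    lengths.checkpoint() // pin the lineage to a stable, materialized snapshot
    // A retried map stage now re-reads identical checkpointed input, so the
    // round-robin repartition no longer produces indeterminate output.
    println(lengths.repartition(100).count())
    spark.stop()
  }
}
```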
@@ -1552,6 +1552,26 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
// already executed at least once
if (sms.getNextAttemptId > 0) {
// While we previously validated possible rollbacks during the handling of a FetchFailure,
// where we were fetching from an indeterminate source map stage, this later check
// covers additional cases, such as recomputing an indeterminate stage after an executor
// loss. Moreover, because this check occurs later in the process, if a result stage task
// has already completed successfully, we can detect this and abort the job, as rolling
// back a result stage is not possible.
val stagesToRollback = collectSucceedingStages(sms)
abortStageWithInvalidRollBack(stagesToRollback)
// Stages which cannot be rolled back were aborted, which leads to removing the
// dependent job(s) from the active jobs set
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
if (numActiveJobsWithStageAfterRollback == 0) {
logInfo(log"All jobs depending on the indeterminate stage " +
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
return
}
}
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
case _ =>
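A self-contained sketch of the rule this new block applies, using simplified stand-in types rather than the real `Stage` hierarchy: map-stage progress can be thrown away and recomputed, but a result stage with even one finished task cannot be rolled back, so its job must be aborted:

```scala
// Hypothetical, simplified stand-ins for ShuffleMapStage / ResultStage.
sealed trait StageLike { def completedTasks: Int }
final case class MapStageLike(completedTasks: Int) extends StageLike
final case class ResultStageLike(completedTasks: Int) extends StageLike

// Completed result tasks are irreversible; map output can be discarded.
def mustAbort(stage: StageLike): Boolean = stage match {
  case r: ResultStageLike => r.completedTasks > 0 // no rollback for result tasks
  case _                  => false
}

assert(mustAbort(ResultStageLike(completedTasks = 1)))  // partially done result stage
assert(!mustAbort(MapStageLike(completedTasks = 2)))    // map stages can always rerun
```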
@@ -2129,60 +2149,8 @@ private[spark] class DAGScheduler(
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate) {
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only knows its parents, not its children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and roll back all the stages
// in the stage chains that connect to the `mapStage`. To speed up the stage
// traversal, we collect the stages to roll back first. If a stage needs to
// roll back, all its succeeding stages need to roll back too.
val stagesToRollback = HashSet[Stage](mapStage)

def collectStagesToRollback(stageChain: List[Stage]): Unit = {
if (stagesToRollback.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => stagesToRollback += s)
} else {
stageChain.head.parents.foreach { s =>
collectStagesToRollback(s :: stageChain)
}
}
}

def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
s"However, Spark cannot rollback the $stage to re-process the input data, " +
"and has to fail this job. Please eliminate the indeterminacy by " +
"checkpointing the RDD before repartition and try again."
}

activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))

// The stages will be rolled back after checking
val rollingBackStages = HashSet[Stage](mapStage)
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
val numMissingPartitions = mapStage.findMissingPartitions().length
if (numMissingPartitions < mapStage.numTasks) {
if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
val reason = "A shuffle map stage with indeterminate output was failed " +
"and retried. However, Spark can only do this while using the new " +
"shuffle block fetching protocol. Please check the config " +
"'spark.shuffle.useOldFetchProtocol', see more detail in " +
"SPARK-27665 and SPARK-25341."
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
}
}

case resultStage: ResultStage if resultStage.activeJob.isDefined =>
val numMissingPartitions = resultStage.findMissingPartitions().length
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
}

case _ =>
}
val stagesToRollback = collectSucceedingStages(mapStage)
val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
log"we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
@@ -2346,6 +2314,74 @@ private[spark] class DAGScheduler(
}
}

private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = {
// TODO: perhaps materialize this if we are going to compute it often enough?
// It's a little tricky to find all the succeeding stages of `mapStage`, because
// each stage only knows its parents, not its children. Here we traverse the stages from
// the leaf nodes (the result stages of active jobs), and roll back all the stages
// in the stage chains that connect to the `mapStage`. To speed up the stage
// traversal, we collect the stages to roll back first. If a stage needs to
// roll back, all its succeeding stages need to roll back too.
val succeedingStages = HashSet[Stage](mapStage)

def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
if (succeedingStages.contains(stageChain.head)) {
stageChain.drop(1).foreach(s => succeedingStages += s)
} else {
stageChain.head.parents.foreach { s =>
collectSucceedingStagesInternal(s :: stageChain)
}
}
}
activeJobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
succeedingStages
}
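Since each stage only records its parents, the helper walks backwards from every job's final stage, carrying the chain visited so far; the moment a chain reaches a stage already marked as succeeding, everything after it on that chain must succeed it too. The same idea on a toy parent-pointer graph (hypothetical `Node` type, not the scheduler's classes):

```scala
import scala.collection.mutable

// A node that, like a Stage, knows only its parents.
final case class Node(id: Int, parents: List[Node] = Nil)

def collectSucceeding(target: Node, leaves: Seq[Node]): mutable.HashSet[Node] = {
  val succeeding = mutable.HashSet[Node](target)
  def walk(chain: List[Node]): Unit = {
    if (succeeding.contains(chain.head)) {
      chain.drop(1).foreach(succeeding += _) // rest of the chain is downstream
    } else {
      chain.head.parents.foreach(p => walk(p :: chain))
    }
  }
  leaves.foreach(leaf => walk(leaf :: Nil))
  succeeding
}

// a <- b <- c, where c is the leaf (result) node: collecting from a
// yields a itself plus both of its succeeding nodes.
val a = Node(1); val b = Node(2, List(a)); val c = Node(3, List(b))
assert(collectSucceeding(a, Seq(c)) == mutable.HashSet(a, b, c))
```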

/**
* Abort the stages for which a rollback is required but cannot be performed.
*
* @param stagesToRollback stages to roll back
* @return the shuffle map stages which both need to be and can be rolled back
*/
private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {

def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
s"However, Spark cannot rollback the $stage to re-process the input data, " +
"and has to fail this job. Please eliminate the indeterminacy by " +
"checkpointing the RDD before repartition and try again."
}

// The stages will be rolled back after checking
val rollingBackStages = HashSet[Stage]()
stagesToRollback.foreach {
case mapStage: ShuffleMapStage =>
if (mapStage.numAvailableOutputs > 0) {
if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) {
val reason = "A shuffle map stage with indeterminate output was failed " +
"and retried. However, Spark can only do this while using the new " +
"shuffle block fetching protocol. Please check the config " +
"'spark.shuffle.useOldFetchProtocol', see more detail in " +
"SPARK-27665 and SPARK-25341."
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
}
}

case resultStage: ResultStage if resultStage.activeJob.isDefined =>
val numMissingPartitions = resultStage.findMissingPartitions().length
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
}

case _ =>
}

rollingBackStages
}
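Putting the two helpers together, the resubmission hunk at the top of this diff resubmits the indeterminate stage only if some active job survived the aborts. A tiny self-contained model of that final check (hypothetical `Mini*` types, not the scheduler's classes):

```scala
// Stand-ins for Stage and ActiveJob, just enough to model the check.
final case class MiniStage(id: Int)
final case class MiniJob(finalStage: MiniStage)

// Mirrors: activeJobs.count(job => stagesToRollback.contains(job.finalStage))
def stillNeeded(activeJobs: Set[MiniJob], stagesToRollback: Set[MiniStage]): Boolean =
  activeJobs.exists(job => stagesToRollback.contains(job.finalStage))

// The only dependent job was aborted (and removed from the active set), so
// the indeterminate map stage is no longer needed and is not resubmitted.
assert(!stillNeeded(Set.empty, Set(MiniStage(1), MiniStage(2))))
```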

/**
* Whether executor is decommissioning or decommissioned.
* Return true when: