11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -434,6 +434,17 @@ private[spark] class MapOutputTrackerMaster(
}
}

/** Unregister all map output information of the given shuffle. */
def unregisterAllMapOutput(shuffleId: Int) {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
incrementEpoch()
case None =>
throw new SparkException("unregisterAllMapOutput called for nonexistent shuffle ID")
}
}

/** Unregister shuffle data */
def unregisterShuffle(shuffleId: Int) {
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
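
As a side note on what `removeOutputsByFilter(x => true)` buys in the new `unregisterAllMapOutput` above, here is a minimal, self-contained sketch (plain Scala, illustrative names only, not Spark's actual `ShuffleStatus` API) contrasting dropping a single map output with dropping all of them:

```scala
// Simplified stand-in for a shuffle's map-output registry; illustrative only.
final class SimpleShuffleStatus(numMaps: Int) {
  // mapStatuses(i) holds the executor hosting map output i, or None if the output is lost.
  private val mapStatuses = Array.fill[Option[String]](numMaps)(None)

  def register(mapId: Int, execId: String): Unit = mapStatuses(mapId) = Some(execId)

  // Drop every output whose executor matches the predicate, like removeOutputsByFilter.
  def removeOutputsByFilter(p: String => Boolean): Unit =
    for (i <- mapStatuses.indices if mapStatuses(i).exists(p)) mapStatuses(i) = None

  def missingPartitions: Seq[Int] = mapStatuses.indices.filter(mapStatuses(_).isEmpty)
}

object UnregisterSketch {
  def main(args: Array[String]): Unit = {
    val status = new SimpleShuffleStatus(numMaps = 3)
    (0 until 3).foreach(i => status.register(i, execId = s"exec-$i"))

    // Unregistering one lost output only forces that single map task to rerun ...
    status.removeOutputsByFilter(_ == "exec-1")
    println(status.missingPartitions) // only map output 1 is missing

    // ... whereas the `x => true` filter wipes every output, so the whole map stage reruns.
    status.removeOutputsByFilter(_ => true)
    println(status.missingPartitions) // all map outputs are missing
  }
}
```

Wiping all outputs is what forces the scheduler to rerun every map task of the stage when it is resubmitted.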
core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -27,7 +27,8 @@ import org.apache.spark.{Partition, TaskContext}
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
var prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
preservesPartitioning: Boolean = false,
retryOnAllPartitionsOnFailure: Boolean = false)
extends RDD[U](prev) {

override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +42,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
super.clearDependencies()
prev = null
}

override def recomputeAllPartitionsOnFailure(): Boolean = retryOnAllPartitionsOnFailure
}
36 changes: 32 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -452,6 +452,10 @@ abstract class RDD[T: ClassTag](
/** Distributes elements evenly across output partitions, starting from a random partition. */
val distributePartition = (index: Int, items: Iterator[T]) => {
var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
// TODO: Consider inserting a local sort before the shuffle to make the input data sequence
// deterministic, which would avoid retrying all partitions on FetchFailure. However, a local
// sort before the shuffle may increase the execution time of repartition() significantly
// (3x ~ 5x for some large inputs).
items.map { t =>
// Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
@@ -462,8 +466,11 @@

// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
new ShuffledRDD[Int, T, T](
mapPartitionsWithIndexInternal(
distributePartition,
retryOnAllPartitionsOnFailure = true),
new HashPartitioner(numPartitions)),
numPartitions,
partitionCoalescer).values
} else {
@@ -809,14 +816,19 @@ abstract class RDD[T: ClassTag](
* @param preservesPartitioning indicates whether the input function preserves the partitioner,
* which should be `false` unless this is a pair RDD and the input function doesn't modify
* the keys.
*
* @param retryOnAllPartitionsOnFailure indicates whether to recompute all partitions on
* failure recovery, which should be `false` unless the output is repartitioned.
*/
private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
preservesPartitioning: Boolean = false,
retryOnAllPartitionsOnFailure: Boolean = false): RDD[U] = withScope {
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
preservesPartitioning)
preservesPartitioning,
retryOnAllPartitionsOnFailure)
}

/**
@@ -1839,6 +1851,22 @@ abstract class RDD[T: ClassTag](
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
}

/**
 * Whether the RDD must recompute all partitions on FetchFailure. repartition() distributes
 * records in a round-robin manner, so there may be a data correctness issue if only a subset
 * of partitions is recomputed on FetchFailure and the input data sequence is not deterministic.
 * Please refer to SPARK-23207 and SPARK-23243 for the related discussion.
 *
 * Ideally we would not need to recompute all partitions on FetchFailure when the result
 * sequence of an RDD is deterministic, but various sources outside Spark's control can make the
 * result sequence non-deterministic (e.g. reading from an external data source, or different
 * spill-and-merge patterns under memory pressure). We cannot accept the performance degradation
 * of inserting a local sort before the shuffle (which can cost 3x ~ 5x for repartition()), and
 * the element type of an RDD may not even be sortable. For these reasons we compromise and
 * simply require recomputing all partitions on FetchFailure whenever a repartition operation is
 * called on an RDD.
 */
private[spark] def recomputeAllPartitionsOnFailure(): Boolean = false
}
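
To make the round-robin hazard described in the doc comment above concrete, the following standalone sketch (plain Scala, no Spark dependency; `assign` is a hypothetical stand-in for `distributePartition`) shows that the target partition of each element depends on the order in which the input iterator produces it:

```scala
import scala.util.Random
import scala.util.hashing

object RoundRobinSketch {
  // Mimics distributePartition: start from a pseudo-random position derived from the
  // partition index, then hand out elements to output partitions round-robin.
  def assign[T](index: Int, items: Iterator[T], numPartitions: Int): Seq[(Int, T)] = {
    var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
    items.map { t =>
      position += 1
      (position % numPartitions, t) // (target partition, element)
    }.toSeq
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    val runA = assign(0, Iterator("a", "b", "c", "d"), numPartitions)
    val runB = assign(0, Iterator("b", "a", "c", "d"), numPartitions) // same elements, reordered

    println(runA) // "a" and "b" land in different target partitions than in runB
    println(runB)
    // If this map task is recomputed after a FetchFailure and its input arrives in a
    // different order, reducers that already fetched from the first run see duplicated
    // records while others lose records, hence the decision to recompute all partitions.
  }
}
```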


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1323,17 +1323,17 @@ class DAGScheduler(
}

case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val failedStage = stage
val mapStage = shuffleIdToMapStage(shuffleId)

if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
} else {
failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
failedStage.failedAttemptIds.add(task.stageAttemptId)
val shouldAbortStage =
failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
disallowStageRetryForTest

// It is likely that we receive multiple FetchFailed for a single stage (because we have
Expand Down Expand Up @@ -1386,8 +1386,12 @@ class DAGScheduler(
)
}
}
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {

if (mapStage.rdd.recomputeAllPartitionsOnFailure()) {
Contributor comment:

hmmm, what if we have a map after repartition? then the root RDD will return false on recomputeAllPartitionsOnFailure

// Mark all the map outputs as broken in the map stage, to ensure all partitions are
// recomputed on the resubmitted stage attempt.
mapOutputTracker.unregisterAllMapOutput(shuffleId)
} else if (mapId != -1) {
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
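
Regarding the contributor comment above about a map() applied after repartition(): here is a tiny standalone sketch (illustrative lineage classes, not Spark's RDD API) of why a per-RDD flag that defaults to false is invisible when only the stage's final RDD is checked, and how walking the lineage would surface it:

```scala
object FlagPropagationSketch {
  // Minimal stand-in for an RDD lineage node; illustrative only.
  sealed trait Node {
    def parent: Option[Node]
    // Mirrors recomputeAllPartitionsOnFailure(): false unless a node opts in.
    def recomputeAllPartitionsOnFailure: Boolean = false
  }
  case object Source extends Node { val parent: Option[Node] = None }
  final case class Repartitioned(parent: Option[Node]) extends Node {
    override def recomputeAllPartitionsOnFailure: Boolean = true
  }
  final case class Mapped(parent: Option[Node]) extends Node

  def main(args: Array[String]): Unit = {
    // rdd.repartition(n).map(f): the repartitioned node is no longer the stage's last RDD.
    val lineage = Mapped(Some(Repartitioned(Some(Source))))

    // One possible fix: check the whole lineage instead of just the final RDD.
    def anyInLineage(n: Node): Boolean =
      n.recomputeAllPartitionsOnFailure || n.parent.exists(anyInLineage)

    println(lineage.recomputeAllPartitionsOnFailure) // false: the flag is hidden by map()
    println(anyInLineage(lineage))                   // true: found on the repartition node
  }
}
```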

8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -82,15 +82,15 @@ private[scheduler] abstract class Stage(
private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)

/**
* Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these
* failures in order to avoid endless retries if a stage keeps failing with a FetchFailure.
* Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid
* endless retries if a stage keeps failing.
* We keep track of each attempt ID that has failed to avoid recording duplicate failures if
* multiple tasks from the same stage attempt fail (SPARK-5945).
*/
val fetchFailedAttemptIds = new HashSet[Int]
val failedAttemptIds = new HashSet[Int]
Contributor comment:

why rename it? we only increase it on fetch failure, don't we?


private[scheduler] def clearFailures() : Unit = {
fetchFailedAttemptIds.clear()
failedAttemptIds.clear()
}

/** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
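
To illustrate the note in the doc comment above about avoiding duplicate failure records (SPARK-5945), here is a small standalone sketch (illustrative, not the actual Stage class) of why a set of attempt IDs is used rather than a plain counter: many failed tasks from one stage attempt count only once toward the abort threshold.

```scala
import scala.collection.mutable

object FailedAttemptSketch {
  val maxConsecutiveStageAttempts = 4
  val failedAttemptIds = new mutable.HashSet[Int]

  // Called once per failed task; attemptId is the stage attempt the task belonged to.
  def onTaskFailure(attemptId: Int): Boolean = {
    failedAttemptIds.add(attemptId) // duplicate failures from the same attempt are ignored
    failedAttemptIds.size >= maxConsecutiveStageAttempts // whether to abort the stage
  }

  def main(args: Array[String]): Unit = {
    // Ten task failures, all from stage attempt 0: still only one recorded failed attempt.
    val shouldAbort = (1 to 10).map(_ => onTaskFailure(attemptId = 0)).last
    println(failedAttemptIds.size) // 1
    println(shouldAbort)           // false: far from the abort threshold (SPARK-5945)
  }
}
```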