diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 73646051f264..aeec2876a0f0 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -434,6 +434,17 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  /** Unregister all map output information of the given shuffle. */
+  def unregisterAllMapOutput(shuffleId: Int) {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeOutputsByFilter(x => true)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterAllMapOutput called for nonexistent shuffle ID")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int) {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c96eae1..8de3498aea48 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -27,7 +27,8 @@ import org.apache.spark.{Partition, TaskContext}
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     var prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
-    preservesPartitioning: Boolean = false)
+    preservesPartitioning: Boolean = false,
+    retryOnAllPartitionsOnFailure: Boolean = false)
   extends RDD[U](prev) {
 
   override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +42,6 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     super.clearDependencies()
     prev = null
   }
+
+  override def recomputeAllPartitionsOnFailure(): Boolean = retryOnAllPartitionsOnFailure
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0574abdca32a..62b2b6e36074 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -452,6 +452,10 @@ abstract class RDD[T: ClassTag](
       /** Distributes elements evenly across output partitions, starting from a random partition. */
       val distributePartition = (index: Int, items: Iterator[T]) => {
         var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
+        // TODO: Consider inserting a local sort before the shuffle to make the input data
+        // sequence deterministic and avoid retrying all partitions on a FetchFailure. However,
+        // a local sort before the shuffle may increase the execution time of repartition()
+        // significantly (3x ~ 5x for some large inputs).
         items.map { t =>
           // Note that the hash code of the key will just be the key itself. The HashPartitioner
           // will mod it with the number of total partitions.
@@ -462,8 +466,11 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
-          new HashPartitioner(numPartitions)),
+        new ShuffledRDD[Int, T, T](
+          mapPartitionsWithIndexInternal(
+            distributePartition,
+            retryOnAllPartitionsOnFailure = true),
+          new HashPartitioner(numPartitions)),
         numPartitions,
         partitionCoalescer).values
     } else {
@@ -809,14 +816,19 @@ abstract class RDD[T: ClassTag](
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
    * which should be `false` unless this is a pair RDD and the input function doesn't modify
    * the keys.
+   *
+   * @param retryOnAllPartitionsOnFailure indicates whether all partitions should be recomputed
+   * on failure recovery, which should be `false` unless the output is repartitioned.
    */
   private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+      preservesPartitioning: Boolean = false,
+      retryOnAllPartitionsOnFailure: Boolean = false): RDD[U] = withScope {
     new MapPartitionsRDD(
       this,
       (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
-      preservesPartitioning)
+      preservesPartitioning,
+      retryOnAllPartitionsOnFailure)
   }
 
   /**
@@ -1839,6 +1851,22 @@ abstract class RDD[T: ClassTag](
 
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
+
+  /**
+   * Whether or not this RDD is required to recompute all partitions on a FetchFailure.
+   * repartition() distributes records in a round-robin manner, so there may be a data correctness
+   * issue if only a subset of partitions is recomputed on a FetchFailure and the input data
+   * sequence is not deterministic. Please refer to SPARK-23207 and SPARK-23243 for the related
+   * discussion.
+   *
+   * Ideally we would not need to recompute all partitions when the result sequence of an RDD is
+   * deterministic, but sources outside Spark's control can make it non-deterministic (e.g. reading
+   * from an external data source, or different spill-and-merge patterns under memory pressure).
+   * We cannot afford the cost of a local sort before the shuffle (3x ~ 5x for repartition()), and
+   * the element type of an RDD may not even be sortable. As a compromise, we recompute all
+   * partitions on a FetchFailure whenever a repartition operation is applied to an RDD.
+   */
+  private[spark] def recomputeAllPartitionsOnFailure(): Boolean = false
 }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f74425d73b39..3fab97322ffa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1323,7 +1323,7 @@ class DAGScheduler(
         }
 
       case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
-        val failedStage = stageIdToStage(task.stageId)
+        val failedStage = stage
         val mapStage = shuffleIdToMapStage(shuffleId)
 
         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
@@ -1331,9 +1331,9 @@ class DAGScheduler(
             s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
             s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
         } else {
-          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+          failedStage.failedAttemptIds.add(task.stageAttemptId)
           val shouldAbortStage =
-            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
             disallowStageRetryForTest
 
           // It is likely that we receive multiple FetchFailed for a single stage (because we have
@@ -1386,8 +1386,12 @@ class DAGScheduler(
               )
             }
           }
-          // Mark the map whose fetch failed as broken in the map stage
-          if (mapId != -1) {
+
+          if (mapStage.rdd.recomputeAllPartitionsOnFailure()) {
+            // Mark all the map outputs as broken in the map stage, to ensure that all the
+            // partitions are recomputed on the resubmitted stage attempt.
+            mapOutputTracker.unregisterAllMapOutput(shuffleId)
+          } else if (mapId != -1) {
             mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
           }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 290fd073caf2..26cca334d3bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -82,15 +82,15 @@ private[scheduler] abstract class Stage(
   private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
 
   /**
-   * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these
-   * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure.
+   * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid
+   * endless retries if a stage keeps failing.
    * We keep track of each attempt ID that has failed to avoid recording duplicate failures if
    * multiple tasks from the same stage attempt fail (SPARK-5945).
    */
-  val fetchFailedAttemptIds = new HashSet[Int]
+  val failedAttemptIds = new HashSet[Int]
 
   private[scheduler] def clearFailures() : Unit = {
-    fetchFailedAttemptIds.clear()
+    failedAttemptIds.clear()
   }
 
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
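For readers who want to see why a partial recompute is unsafe here, the sketch below (illustrative only, not part of the patch; RoundRobinDemo and assign are made-up names) mimics the round-robin assignment performed by the distributePartition closure in the RDD.scala hunk above. Recomputing the same map partition with a different input order sends different records to each reducer, which is the hazard that recomputeAllPartitionsOnFailure() and unregisterAllMapOutput guard against.

---- Illustrative snippet (not part of the patch) ----

import scala.util.Random
import scala.util.hashing

object RoundRobinDemo {
  // Standalone sketch of repartition()'s round-robin record distribution: the target
  // partition of each record is a per-partition random start position plus the record's
  // arrival order, so the assignment depends on the order in which records are iterated.
  def assign[T](partitionIndex: Int, items: Iterator[T], numPartitions: Int): Iterator[(Int, T)] = {
    var position = new Random(hashing.byteswap32(partitionIndex)).nextInt(numPartitions)
    items.map { t =>
      position = position + 1
      (position % numPartitions, t) // the key decides which reducer receives the record
    }
  }

  def main(args: Array[String]): Unit = {
    val firstRun = Seq("a", "b", "c", "d")
    val recomputedRun = Seq("b", "a", "d", "c") // same records, non-deterministic order
    // Rerunning only this map partition with a different input order changes which records
    // go to which reducer, so reducers that are not rerun may see lost or duplicated records.
    println(assign(0, firstRun.iterator, numPartitions = 2).toList)
    println(assign(0, recomputedRun.iterator, numPartitions = 2).toList)
  }
}

This is also why the DAGScheduler change above calls unregisterAllMapOutput(shuffleId) when recomputeAllPartitionsOnFailure() returns true: wiping every map output forces the whole map stage, and therefore every reducer's input, to be regenerated consistently.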