
Commit eb427f7

address comments
1 parent 9a7b053 commit eb427f7

4 files changed (+24, -26 lines)

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 1 addition & 1 deletion

@@ -157,7 +157,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
 
   def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.markPartitionCompleted(stageId, partitionId)
+      scheduler.handlePartitionCompleted(stageId, partitionId)
     })
   }

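The change in this file is just the call-site rename, but the pattern around it is worth a gloss: the notification runs asynchronously on the task-result executor, wrapped so an uncaught exception is logged rather than silently killing the pool thread. Below is a minimal, self-contained sketch of that pattern; it assumes nothing about Spark's actual Utils or thread pool, and resultGetterPool and logUncaughtExceptions are illustrative stand-ins.

import java.util.concurrent.Executors

object AsyncNotificationDemo {
  // Stand-in for the task-result-getter pool (name is hypothetical).
  private val resultGetterPool = Executors.newFixedThreadPool(4)

  // Stand-in for Utils.logUncaughtExceptions: log, then rethrow.
  private def logUncaughtExceptions(block: => Unit): Unit =
    try block catch {
      case t: Throwable =>
        System.err.println(s"Uncaught exception in result getter: $t")
        throw t
    }

  // Mirrors the shape of enqueuePartitionCompletionNotification:
  // the caller never blocks; the work happens on a pool thread.
  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit =
    resultGetterPool.execute(() => logUncaughtExceptions {
      println(s"partition $partitionId of stage $stageId completed")
    })

  def main(args: Array[String]): Unit = {
    enqueuePartitionCompletionNotification(0, 3)
    resultGetterPool.shutdown()
  }
}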
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 17 additions & 17 deletions

@@ -641,6 +641,23 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  /**
+   * Marks the task as completed in the active TaskSetManager for the given stage.
+   *
+   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+   * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt
+   * to skip submitting and running the task for the same partition, to save resources. That also
+   * means that a task completion from an earlier zombie attempt can lead to the entire stage
+   * getting marked as successful.
+   */
+  private[scheduler] def handlePartitionCompleted(
+      stageId: Int,
+      partitionId: Int) = synchronized {
+    taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
+      tsm.markPartitionCompleted(partitionId)
+    })
+  }
+
   def error(message: String) {
     synchronized {
       if (taskSetsByStageIdAndAttempt.nonEmpty) {

@@ -872,23 +889,6 @@ private[spark] class TaskSchedulerImpl(
       manager
     }
   }
-
-  /**
-   * Marks the task has completed in the active TaskSetManager for the given stage.
-   *
-   * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
-   * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt
-   * to skip submitting and running the task for the same partition, to save resource. That also
-   * means that a task completion from an earlier zombie attempt can lead to the entire stage
-   * getting marked as successful.
-   */
-  private[scheduler] def markPartitionCompleted(
-      stageId: Int,
-      partitionId: Int) = {
-    taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
-    })
-  }
 }

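To make the zombie-attempt semantics in the new doc comment concrete, here is a minimal, runnable sketch. These are not Spark's real classes: FakeTaskSetManager and PartitionCompletionDemo are hypothetical, but the filtering mirrors handlePartitionCompleted above, where every non-zombie attempt of the stage is told the partition is done so the active attempt will not resubmit it.

import scala.collection.mutable

// Hypothetical stand-in for a TaskSetManager: only the fields the
// sketch needs (attempt id, zombie flag, completed partitions).
class FakeTaskSetManager(val stageAttemptId: Int, var isZombie: Boolean = false) {
  val completedPartitions = mutable.Set[Int]()
  def markPartitionCompleted(partitionId: Int): Unit =
    completedPartitions += partitionId
}

object PartitionCompletionDemo {
  // stageId -> (stageAttemptId -> manager), mirroring the shape of
  // taskSetsByStageIdAndAttempt in the hunk above.
  val taskSetsByStageIdAndAttempt =
    mutable.HashMap[Int, mutable.HashMap[Int, FakeTaskSetManager]]()

  // Same filtering as handlePartitionCompleted: only non-zombie
  // attempts are told the partition is done.
  def handlePartitionCompleted(stageId: Int, partitionId: Int): Unit = synchronized {
    taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
      tsm.markPartitionCompleted(partitionId)
    })
  }

  def main(args: Array[String]): Unit = {
    val zombie = new FakeTaskSetManager(stageAttemptId = 0, isZombie = true)
    val active = new FakeTaskSetManager(stageAttemptId = 1)
    taskSetsByStageIdAndAttempt(5) = mutable.HashMap(0 -> zombie, 1 -> active)

    // A task from the zombie attempt finished partition 3: the active
    // attempt records it and will skip resubmitting that partition.
    handlePartitionCompleted(stageId = 5, partitionId = 3)
    assert(active.completedPartitions == Set(3))
    assert(zombie.completedPartitions.isEmpty)
  }
}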
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 5 additions & 5 deletions

@@ -819,6 +819,10 @@ private[spark] class TaskSetManager(
   private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
+        if (speculationEnabled && !isZombie) {
+          // The task is skipped, its duration should be 0.
+          successfulTaskDurations.insert(0)
+        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {

@@ -1035,11 +1039,7 @@
     val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
     logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
 
-    // It's possible that a task is marked as completed by the scheduler, then the size of
-    // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the
-    // tasks that are submitted by this `TaskSetManager` and are completed successfully.
-    val numSuccessfulTasks = successfulTaskDurations.size()
-    if (numSuccessfulTasks >= minFinishedForSpeculation && numSuccessfulTasks > 0) {
+    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val time = clock.getTimeMillis()
      val medianDuration = successfulTaskDurations.median
      val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)

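The insert(0) above is what makes the simplified check safe: every partition counted in tasksSuccessful now also has an entry in successfulTaskDurations, so the two counts agree. The cost is that zero durations drag the median down. A small arithmetic sketch of the threshold computation follows, assuming Spark's documented defaults of 0.75 for spark.speculation.quantile and 1.5 for spark.speculation.multiplier; the 100 ms floor and the plain upper-median (Spark uses a MedianHeap) are simplifying assumptions here.

object SpeculationMath {
  val speculationQuantile = 0.75   // assumed: spark.speculation.quantile default
  val speculationMultiplier = 1.5  // assumed: spark.speculation.multiplier default
  val minTimeToSpeculation = 100L  // ms; assumed floor

  // Plain upper-median over sorted durations (Spark uses a MedianHeap).
  def threshold(successfulDurations: Seq[Long]): Long = {
    val sorted = successfulDurations.sorted
    val median = sorted(sorted.length / 2)
    math.max((speculationMultiplier * median).toLong, minTimeToSpeculation)
  }

  def main(args: Array[String]): Unit = {
    val numTasks = 10
    // With quantile 0.75, speculation starts once 7 of 10 tasks have finished.
    val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
    assert(minFinishedForSpeculation == 7)

    // 8 tasks actually ran for ~1 s each: speculate on anything over 1.5 s.
    println(threshold(Seq.fill(8)(1000L)))                    // 1500

    // If 5 of the 8 entries are skipped partitions recorded as 0, the median
    // is 0 and the threshold collapses to the floor: speculation gets eager.
    println(threshold(Seq.fill(5)(0L) ++ Seq.fill(3)(1000L))) // 100
  }
}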
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 1 addition & 3 deletions

@@ -1387,9 +1387,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     sched.setDAGScheduler(dagScheduler)
 
     val taskSet = FakeTask.createTaskSet(10)
-    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
-      task.metrics.internalAccums
-    }
 
     sched.submitTasks(taskSet)
     sched.resourceOffers(

@@ -1398,6 +1395,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
     taskSetManager.markPartitionCompleted(8)
+    assert(!taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
