@@ -75,9 +75,11 @@ private[spark] class TaskSchedulerImpl(

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
/** Mapping from stage ID to the TaskSetManager for the most recent attempt for the stage.*/
val stageIdToLatestTaskSet = new HashMap[Int, TaskSetManager]
/** Mapping from task ID to the TaskSetManager for that task. */
val taskIdToTaskSet = new HashMap[Long, TaskSetManager]

val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
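
For orientation, here is a minimal sketch of the bookkeeping change, not part of the diff (the TaskSetManager below is a bare stand-in for the real class): the old nested map plus taskIdToStageIdAndAttempt are replaced by two flat maps, which turns the old two-step task lookup into a single get.

import scala.collection.mutable.HashMap

// Stand-in for org.apache.spark.scheduler.TaskSetManager, for illustration only.
class TaskSetManager

object MappingSketch {
  // Before: a nested map keyed by stage ID, then by stage attempt ID,
  // plus a task ID -> (stage ID, attempt ID) map used to reach a manager.
  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]

  // After: only the most recent attempt per stage is tracked, and each task ID
  // points directly at its TaskSetManager.
  val stageIdToLatestTaskSet = new HashMap[Int, TaskSetManager]
  val taskIdToTaskSet = new HashMap[Long, TaskSetManager]

  // Old two-step lookup of a task's manager vs. the new single lookup.
  def oldLookup(taskId: Long): Option[TaskSetManager] =
    taskIdToStageIdAndAttempt.get(taskId).flatMap { case (stageId, attemptId) =>
      taskSetsByStageIdAndAttempt.get(stageId).flatMap(_.get(attemptId))
    }

  def newLookup(taskId: Long): Option[TaskSetManager] =
    taskIdToTaskSet.get(taskId)
}

The trade-off, raised in the review thread further down, is that only the latest attempt per stage remains reachable by stage ID.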
@@ -162,17 +164,14 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
val stageId = taskSet.stageId
// Make sure there aren't already any running attempts for the stageId.
stageIdToLatestTaskSet.get(taskSet.stageId).filter(!_.isZombie).foreach { taskSetManager =>
throw new IllegalStateException(
s"Already one active task set for stageId $stageId (attempt : " +
s"${taskSetManager.taskSet.stageAttemptId})")
}
stageIdToLatestTaskSet.put(taskSet.stageId, manager)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
@@ -202,21 +201,19 @@ private[spark] class TaskSchedulerImpl(

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
stageIdToLatestTaskSet.get(stageId).foreach { taskSetManager =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
// simply abort the stage.
taskSetManager.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)

Owner: There is a change in behavior here. If you've got a zombie attempt with some tasks still running, but you also have a new active attempt, then when cancelTasks gets called, you'll only kill the tasks for the active attempt. The zombie attempt will no longer have its tasks cancelled.

I know this is a corner case, but it just means the behavior will be even more confusing when we do encounter it ...

Owner: So one way around this would be if we cancelled tasks as soon as an attempt was marked zombie (which we should probably do anyway; we have several lingering TODOs for it). But I'm thinking maybe we can save this for a future improvement?

Owner: One more thought: one thing which kinda bugs me now is that TaskSetManager.isZombie is just a var, so a potential bug is if somebody decides to set it back to true. Really it should be private, and we should have a markZombie() method or something, which would also be where we could create a request to cancel all of the tasks. (A rough sketch of this idea follows after this hunk.)

Collaborator Author: Agree with all of the above, especially the bit about actually canceling tasks when an attempt is marked zombie, because it's so easy to do.

Collaborator Author: (and fine to save this for a future cleanup)

Owner: Also realizing this is yet another untested part of TaskSchedulerImpl ...

}
taskSetManager.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
}
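
Picking up the markZombie() suggestion from the review thread above, here is a rough, hypothetical sketch (not part of this PR; SchedulerBackendLike, TaskSetManagerSketch, and the taskIdToExecutorId function parameter are invented for illustration, while the killTask signature mirrors the call used in the diff): isZombie becomes private and can only be flipped through a method that also asks the backend to kill the attempt's running tasks.

import scala.collection.mutable.HashSet

// Hypothetical interface with the one backend call the sketch needs.
trait SchedulerBackendLike {
  def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit
}

// Sketch of a TaskSetManager where the zombie flag can only be set, never unset,
// and where marking an attempt as a zombie proactively cancels its running tasks.
class TaskSetManagerSketch(
    backend: SchedulerBackendLike,
    taskIdToExecutorId: Long => Option[String]) {

  // Task IDs from this attempt that are currently running on executors.
  val runningTasksSet = new HashSet[Long]

  private var _isZombie = false
  def isZombie: Boolean = _isZombie

  /** Mark this attempt as a zombie and request that its running tasks be killed. */
  def markZombie(interruptThread: Boolean = false): Unit = {
    if (!_isZombie) {
      _isZombie = true
      runningTasksSet.foreach { tid =>
        taskIdToExecutorId(tid).foreach { execId =>
          backend.killTask(tid, execId, interruptThread)
        }
      }
    }
  }
}

Cancelling at zombie-marking time would also narrow the cancelTasks gap discussed above, since a zombie attempt's running tasks would already have been killed by the time a newer attempt replaces it in stageIdToLatestTaskSet.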

@@ -226,11 +223,11 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.stageAttemptId
if (taskSetsForStage.isEmpty) {
taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
}
// Remove the TaskSetManager from stageIdToLatestTaskSet if it is the currently active task
// set manager (there may be a TaskSetManager corresponding to a newer attempt for the stage
// in stageIdToLatestTaskSet, in which case we don't want to remove it).
stageIdToLatestTaskSet.get(manager.stageId).foreach { activeTaskSetManager =>
if (activeTaskSetManager == manager) stageIdToLatestTaskSet -= manager.stageId
}
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
Expand All @@ -252,8 +249,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToStageIdAndAttempt(tid) =
(taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
taskIdToTaskSet(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +333,10 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
taskSetManagerForTask(tid) match {
taskIdToTaskSet.get(tid) match {
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToStageIdAndAttempt.remove(tid)
taskIdToTaskSet.remove(tid)
taskIdToExecutorId.remove(tid)
}
if (state == TaskState.FINISHED) {
@@ -379,12 +375,8 @@ private[spark] class TaskSchedulerImpl(

val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
for {
(stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
taskSetMgr <- attempts.get(stageAttemptId)
} yield {
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
taskIdToTaskSet.get(id).map { taskSetMgr =>
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
}
}
}
@@ -417,12 +409,9 @@ private[spark] class TaskSchedulerImpl(

def error(message: String) {
synchronized {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
if (stageIdToLatestTaskSet.nonEmpty) {
// Have each task set throw a SparkException with the error
for {
attempts <- taskSetsByStageIdAndAttempt.values
manager <- attempts.values
} {
for ((taskSetId, manager) <- stageIdToLatestTaskSet) {
try {
manager.abort(message)
} catch {
@@ -542,24 +531,6 @@ private[spark] class TaskSchedulerImpl(
override def applicationId(): String = backend.applicationId()

override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
taskSetManagerForAttempt(stageId, stageAttemptId)
}
}

private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
stageAttemptId: Int): Option[TaskSetManager] = {
for {
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
manager <- attempts.get(stageAttemptId)
} yield {
manager
}
}

}


@@ -191,7 +191,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
scheduler.taskIdToTaskSet.get(task.taskId).foreach { taskSet =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
@@ -52,6 +52,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
taskScheduler.submitTasks(taskSet)
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions.length)
taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))
taskDescriptions(0).executorId
}
val count = selectedExecutorIds.count(_ == workerOffers(0).executorId)
@@ -78,6 +79,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
taskScheduler.submitTasks(taskSet)
var taskDescriptions = taskScheduler.resourceOffers(zeroCoreWorkerOffers).flatten
assert(0 === taskDescriptions.length)
taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))

// No tasks should run as we only have 1 core free.
val numFreeCores = 1
@@ -86,6 +88,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
taskScheduler.submitTasks(taskSet)
taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten
assert(0 === taskDescriptions.length)
taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))

// Now change the offers to have 2 cores in one executor and verify if it
// is chosen.
@@ -123,6 +126,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
// Even if one of the tasks has not-serializable tasks, the other task set should
// still be processed without error
taskScheduler.submitTasks(taskSet)
taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))
taskScheduler.submitTasks(FakeTask.createTaskSet(1))
taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
@@ -144,13 +148,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }

// OK to submit multiple if previous attempts are all zombie
taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
.get.isZombie = true
taskScheduler.stageIdToLatestTaskSet(attempt1.stageId).isZombie = true
taskScheduler.submitTasks(attempt2)
val attempt3 = FakeTask.createTaskSet(1, 2)
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
.get.isZombie = true
taskScheduler.stageIdToLatestTaskSet(attempt2.stageId).isZombie = true
taskScheduler.submitTasks(attempt3)
}

@@ -174,8 +176,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(1 === taskDescriptions.length)

// now mark attempt 1 as a zombie
taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
.get.isZombie = true
taskScheduler.stageIdToLatestTaskSet(attempt1.stageId).isZombie = true

// don't schedule anything on another resource offer
val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
@@ -188,7 +189,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
taskScheduler.submitTasks(attempt2)
val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
assert(1 === taskDescriptions3.length)
val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get
val mgr = taskScheduler.taskIdToTaskSet(taskDescriptions3(0).taskId)
assert(mgr.taskSet.stageAttemptId === 1)
}

Expand All @@ -212,7 +213,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(10 === taskDescriptions.length)

// now mark attempt 1 as a zombie
val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
val mgr1 = taskScheduler.stageIdToLatestTaskSet(attempt1.stageId)
mgr1.isZombie = true

// don't schedule anything on another resource offer
@@ -232,7 +233,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(10 === taskDescriptions3.length)

taskDescriptions3.foreach{ task =>
val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get
val mgr = taskScheduler.taskIdToTaskSet(task.taskId)
assert(mgr.taskSet.stageAttemptId === 1)
}
}