diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a34b67db388f..380696bf35b5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,11 @@ private[spark] class TaskSchedulerImpl(

   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
+  /** Mapping from stage ID to the TaskSetManager for the most recent attempt for the stage. */
+  val stageIdToLatestTaskSet = new HashMap[Int, TaskSetManager]
+  /** Mapping from task ID to the TaskSetManager for that task. */
+  val taskIdToTaskSet = new HashMap[Long, TaskSetManager]

-  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
   val taskIdToExecutorId = new HashMap[Long, String]

   @volatile private var hasReceivedTask = false
@@ -162,17 +164,14 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      val stage = taskSet.stageId
-      val stageTaskSets =
-        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
-      stageTaskSets(taskSet.stageAttemptId) = manager
-      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
-        ts.taskSet != taskSet && !ts.isZombie
-      }
-      if (conflictingTaskSet) {
-        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
-          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      val stageId = taskSet.stageId
+      // Make sure there isn't already a running (non-zombie) attempt for this stage.
+      stageIdToLatestTaskSet.get(taskSet.stageId).filter(!_.isZombie).foreach { taskSetManager =>
+        throw new IllegalStateException(
+          s"There is already an active task set for stage $stageId (attempt " +
+          s"${taskSetManager.taskSet.stageAttemptId})")
       }
+      stageIdToLatestTaskSet.put(taskSet.stageId, manager)
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

       if (!isLocal && !hasReceivedTask) {
@@ -202,21 +201,19 @@ private[spark] class TaskSchedulerImpl(

   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
-      attempts.foreach { case (_, tsm) =>
-        // There are two possible cases here:
-        // 1. The task set manager has been created and some tasks have been scheduled.
-        //    In this case, send a kill signal to the executors to kill the task and then abort
-        //    the stage.
-        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-        //    simply abort the stage.
-        tsm.runningTasksSet.foreach { tid =>
-          val execId = taskIdToExecutorId(tid)
-          backend.killTask(tid, execId, interruptThread)
-        }
-        tsm.abort("Stage %s cancelled".format(stageId))
-        logInfo("Stage %d was cancelled".format(stageId))
+    stageIdToLatestTaskSet.get(stageId).foreach { taskSetManager =>
+      // There are two possible cases here:
+      // 1. The task set manager has been created and some tasks have been scheduled.
+      //    In this case, send a kill signal to the executors to kill the task and then abort
+      //    the stage.
+      // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+      //    simply abort the stage.
+      taskSetManager.runningTasksSet.foreach { tid =>
+        val execId = taskIdToExecutorId(tid)
+        backend.killTask(tid, execId, interruptThread)
       }
+      taskSetManager.abort("Stage %s cancelled".format(stageId))
+      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
@@ -226,11 +223,11 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
-      taskSetsForStage -= manager.taskSet.stageAttemptId
-      if (taskSetsForStage.isEmpty) {
-        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
-      }
+    // Remove the TaskSetManager from stageIdToLatestTaskSet only if it is still the registered
+    // manager for the stage (there may be a TaskSetManager corresponding to a newer attempt for
+    // the stage in stageIdToLatestTaskSet, in which case we don't want to remove it).
+    stageIdToLatestTaskSet.get(manager.stageId).foreach { activeTaskSetManager =>
+      if (activeTaskSetManager == manager) stageIdToLatestTaskSet -= manager.stageId
     }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
@@ -252,8 +249,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToStageIdAndAttempt(tid) =
-              (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
+            taskIdToTaskSet(tid) = taskSet
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +333,10 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskSetManagerForTask(tid) match {
+        taskIdToTaskSet.get(tid) match {
           case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToStageIdAndAttempt.remove(tid)
+              taskIdToTaskSet.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
             if (state == TaskState.FINISHED) {
@@ -379,12 +375,8 @@ private[spark] class TaskSchedulerImpl(

     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        for {
-          (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
-          attempts <- taskSetsByStageIdAndAttempt.get(stageId)
-          taskSetMgr <- attempts.get(stageAttemptId)
-        } yield {
-          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        taskIdToTaskSet.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
         }
       }
     }
@@ -417,12 +409,9 @@ private[spark] class TaskSchedulerImpl(

   def error(message: String) {
     synchronized {
-      if (taskSetsByStageIdAndAttempt.nonEmpty) {
+      if (stageIdToLatestTaskSet.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for {
-          attempts <- taskSetsByStageIdAndAttempt.values
-          manager <- attempts.values
-        } {
+        for ((_, manager) <- stageIdToLatestTaskSet) {
           try {
             manager.abort(message)
           } catch {
@@ -542,24 +531,6 @@ private[spark] class TaskSchedulerImpl(
   override def applicationId(): String = backend.applicationId()

   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
-
-  private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
-    taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
-      taskSetManagerForAttempt(stageId, stageAttemptId)
-    }
-  }
-
-  private[scheduler] def taskSetManagerForAttempt(
-      stageId: Int,
-      stageAttemptId: Int): Option[TaskSetManager] = {
-    for {
-      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
-      manager <- attempts.get(stageAttemptId)
-    } yield {
-      manager
-    }
-  }
-
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 69cea0267438..aef8d6aaefdf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,7 +191,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          scheduler.taskSetManagerForTask(task.taskId).foreach { taskSet =>
+          scheduler.taskIdToTaskSet.get(task.taskId).foreach { taskSet =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index cb0dce44536d..e47eea1fc7a7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -52,6 +52,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
       taskScheduler.submitTasks(taskSet)
       val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
       assert(1 === taskDescriptions.length)
+      taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))
       taskDescriptions(0).executorId
     }
     val count = selectedExecutorIds.count(_ == workerOffers(0).executorId)
@@ -78,6 +79,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     taskScheduler.submitTasks(taskSet)
     var taskDescriptions = taskScheduler.resourceOffers(zeroCoreWorkerOffers).flatten
     assert(0 === taskDescriptions.length)
+    taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))

     // No tasks should run as we only have 1 core free.
     val numFreeCores = 1
@@ -86,6 +88,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     taskScheduler.submitTasks(taskSet)
     taskDescriptions = taskScheduler.resourceOffers(singleCoreWorkerOffers).flatten
     assert(0 === taskDescriptions.length)
+    taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))

     // Now change the offers to have 2 cores in one executor and verify if it
     // is chosen.
@@ -123,6 +126,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     // Even if one of the tasks has not-serializable tasks, the other task set should
     // still be processed without error
     taskScheduler.submitTasks(taskSet)
+    taskScheduler.taskSetFinished(taskScheduler.stageIdToLatestTaskSet(taskSet.stageId))
     taskScheduler.submitTasks(FakeTask.createTaskSet(1))
     taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
     assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
@@ -144,13 +148,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }

     // OK to submit multiple if previous attempts are all zombie
-    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
-      .get.isZombie = true
+    taskScheduler.stageIdToLatestTaskSet(attempt1.stageId).isZombie = true
     taskScheduler.submitTasks(attempt2)
     val attempt3 = FakeTask.createTaskSet(1, 2)
     intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
-    taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
-      .get.isZombie = true
+    taskScheduler.stageIdToLatestTaskSet(attempt2.stageId).isZombie = true
     taskScheduler.submitTasks(attempt3)
   }
@@ -174,8 +176,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(1 === taskDescriptions.length)

     // now mark attempt 1 as a zombie
-    taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
-      .get.isZombie = true
+    taskScheduler.stageIdToLatestTaskSet(attempt1.stageId).isZombie = true

     // don't schedule anything on another resource offer
     val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
@@ -188,7 +189,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     taskScheduler.submitTasks(attempt2)
     val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
     assert(1 === taskDescriptions3.length)
-    val mgr = taskScheduler.taskSetManagerForTask(taskDescriptions3(0).taskId).get
+    val mgr = taskScheduler.taskIdToTaskSet(taskDescriptions3(0).taskId)
     assert(mgr.taskSet.stageAttemptId === 1)
   }
@@ -212,7 +213,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(10 === taskDescriptions.length)

     // now mark attempt 1 as a zombie
-    val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+    val mgr1 = taskScheduler.stageIdToLatestTaskSet(attempt1.stageId)
     mgr1.isZombie = true

     // don't schedule anything on another resource offer
@@ -232,7 +233,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     assert(10 === taskDescriptions3.length)

     taskDescriptions3.foreach{ task =>
-      val mgr = taskScheduler.taskSetManagerForTask(task.taskId).get
+      val mgr = taskScheduler.taskIdToTaskSet(task.taskId)
       assert(mgr.taskSet.stageAttemptId === 1)
     }
   }
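Not part of the patch: the standalone Scala sketch below illustrates the invariant the new bookkeeping depends on. stageIdToLatestTaskSet holds at most one non-zombie attempt per stage, a second submission for the same stage is rejected until the previous attempt is a zombie, and a finished attempt is removed only if it is still the registered one. The names Attempt, latestByStage, register, and finish are hypothetical stand-ins for TaskSetManager, stageIdToLatestTaskSet, submitTasks, and taskSetFinished; they are not Spark APIs.

// Standalone sketch of the single-active-attempt rule; not Spark code.
import scala.collection.mutable

object LatestAttemptSketch {

  // Stand-in for TaskSetManager: one attempt of a stage, which may later become a zombie.
  class Attempt(val stageId: Int, val attemptId: Int, var isZombie: Boolean = false)

  // Stand-in for stageIdToLatestTaskSet: stage ID -> most recently submitted attempt.
  private val latestByStage = new mutable.HashMap[Int, Attempt]

  // Mirrors submitTasks: reject a new attempt while a non-zombie attempt is still registered.
  def register(attempt: Attempt): Unit = {
    latestByStage.get(attempt.stageId).filter(!_.isZombie).foreach { active =>
      throw new IllegalStateException(
        s"There is already an active task set for stage ${attempt.stageId} " +
        s"(attempt ${active.attemptId})")
    }
    latestByStage.put(attempt.stageId, attempt)
  }

  // Mirrors taskSetFinished: drop the entry only if it still points at this attempt, so a
  // newer attempt registered in the meantime is left untouched.
  def finish(attempt: Attempt): Unit = {
    if (latestByStage.get(attempt.stageId).exists(_ eq attempt)) {
      latestByStage -= attempt.stageId
    }
  }

  def main(args: Array[String]): Unit = {
    val attempt0 = new Attempt(stageId = 0, attemptId = 0)
    register(attempt0)
    // A second attempt is rejected while attempt 0 is still active...
    try register(new Attempt(0, 1)) catch {
      case e: IllegalStateException => println(e.getMessage)
    }
    // ...but accepted once attempt 0 is marked as a zombie, matching the
    // "OK to submit multiple if previous attempts are all zombie" test above.
    attempt0.isZombie = true
    register(new Attempt(0, 1))
    println(latestByStage(0).attemptId) // prints 1
  }
}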