diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
index 5e62c8468f00..49174f7c026c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
@@ -52,8 +52,8 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
     val runningTasks2 = s2.runningTasks
     val s1Needy = runningTasks1 < minShare1
     val s2Needy = runningTasks2 < minShare2
-    val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble
-    val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
+    val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0)
+    val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0)
     val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
     val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
     var compare:Int = 0
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 633e892554c5..132cf8aa5fd0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -28,6 +28,8 @@ import scala.collection.mutable.HashSet
 import scala.language.postfixOps
 import scala.util.Random
 
+import akka.actor.Cancellable
+
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@@ -89,11 +91,11 @@ private[spark] class TaskSchedulerImpl(
 
   // The set of executors we have on each host; this is used to compute hostsAlive, which
   // in turn is used to decide when we can attain data locality on a given host
-  protected val executorsByHost = new HashMap[String, HashSet[String]]
+  protected[scheduler] val executorsByHost = new HashMap[String, HashSet[String]]
 
   protected val hostsByRack = new HashMap[String, HashSet[String]]
 
-  protected val executorIdToHost = new HashMap[String, String]
+  protected[scheduler] val executorIdToHost = new HashMap[String, String]
 
   // Listener object to pass upcalls into
   var dagScheduler: DAGScheduler = null
@@ -116,6 +118,8 @@ private[spark] class TaskSchedulerImpl(
   // This is a var so that we can reset it for testing purposes.
   private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
 
+  private[spark] var speculativeExecTask: Cancellable = _
+
   override def setDAGScheduler(dagScheduler: DAGScheduler) {
     this.dagScheduler = dagScheduler
   }
@@ -139,11 +143,10 @@ private[spark] class TaskSchedulerImpl(
 
   override def start() {
     backend.start()
-
     if (!isLocal && conf.getBoolean("spark.speculation", false)) {
       logInfo("Starting speculative execution thread")
       import sc.env.actorSystem.dispatcher
-      sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
+      speculativeExecTask = sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
             SPECULATION_INTERVAL milliseconds) {
         Utils.tryOrExit { checkSpeculatableTasks() }
       }
@@ -397,7 +400,9 @@ private[spark] class TaskSchedulerImpl(
       taskResultGetter.stop()
     }
     starvationTimer.cancel()
-
+    if (speculativeExecTask != null) {
+      speculativeExecTask.cancel()
+    }
     // sleeping for an arbitrary 1 seconds to ensure that messages are sent out.
     Thread.sleep(1000L)
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843f..27ae2ce5d1a4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -412,7 +412,6 @@ private[spark] class TaskSetManager(
         allowedLocality = maxLocality
       }
     }
-
     findTask(execId, host, allowedLocality) match {
       case Some((index, taskLocality, speculative)) => {
         // Found a task; do some bookkeeping and return a task description
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0a7cb69416a0..dc21d7f5711b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -39,4 +39,15 @@ object FakeTask {
     }
     new TaskSet(tasks, 0, 0, 0, null)
   }
+
+  def createTaskSet(numTasks: Int, stageId: Int, attempt: Int,
+      prefLocs: Seq[TaskLocation]*): TaskSet = {
+    if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+      throw new IllegalArgumentException("Wrong number of task locations")
+    }
+    val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+      new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+    }
+    new TaskSet(tasks, stageId, attempt, 0, null)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulingModeSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulingModeSuite.scala
new file mode 100644
index 000000000000..eee77968ef75
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulingModeSuite.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import org.apache.spark.{SparkContext, Logging, LocalSparkContext}
+
+class FakeTaskSetManager(
+    initPriority: Int,
+    initStageId: Int,
+    initNumTasks: Int,
+    taskScheduler: TaskSchedulerImpl,
+    taskSet: TaskSet)
+  extends TaskSetManager(taskScheduler, taskSet, 0) {
+
+  parent = null
+  weight = 1
+  minShare = 2
+  priority = initPriority
+  stageId = initStageId
+  name = "TaskSet_"+stageId
+  override val numTasks = initNumTasks
+  tasksSuccessful = 0
+
+  var numRunningTasks = 0
+  override def runningTasks = numRunningTasks
+
+  def increaseRunningTasks(taskNum: Int) {
+    numRunningTasks += taskNum
+    if (parent != null) {
+      parent.increaseRunningTasks(taskNum)
+    }
+  }
+
+  def decreaseRunningTasks(taskNum: Int) {
+    numRunningTasks -= taskNum
+    if (parent != null) {
+      parent.decreaseRunningTasks(taskNum)
+    }
+  }
+
+  override def addSchedulable(schedulable: Schedulable) {
+  }
+
+  override def removeSchedulable(schedulable: Schedulable) {
+  }
+
+  override def getSchedulableByName(name: String): Schedulable = {
+    null
+  }
+
+  override def executorLost(executorId: String, host: String): Unit = {
+  }
+
+  override def resourceOffer(
+      execId: String,
+      host: String,
+      maxLocality: TaskLocality.TaskLocality)
+    : Option[TaskDescription] =
+  {
+    if (tasksSuccessful + numRunningTasks < numTasks) {
+      increaseRunningTasks(1)
+      Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+    } else {
+      None
+    }
+  }
+
+  override def checkSpeculatableTasks(): Boolean = {
+    true
+  }
+
+  def taskFinished() {
+    decreaseRunningTasks(1)
+    tasksSuccessful +=1
+    if (tasksSuccessful == numTasks) {
+      parent.removeSchedulable(this)
+    }
+  }
+
+  def abort() {
+    decreaseRunningTasks(numRunningTasks)
+    parent.removeSchedulable(this)
+  }
+}
+
+
+class SchedulingModeSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
+  with LocalSparkContext with Logging {
+
+  private var taskScheduler: TaskSchedulerImpl = _
+
+  before {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+  }
+
+  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
+    taskSet: TaskSet): FakeTaskSetManager = {
+    new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
+  }
+
+  def resourceOffer(rootPool: Pool): Int = {
+    val taskSetQueue = rootPool.getSortedTaskSetQueue
+    /* Just for Test*/
+    for (manager <- taskSetQueue) {
+      logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
+        manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+    }
+    for (taskSet <- taskSetQueue) {
+      taskSet.resourceOffer("execId_1", "hostname_1", TaskLocality.ANY) match {
+        case Some(task) =>
+          return taskSet.stageId
+        case None => {}
+      }
+    }
+    -1
+  }
+
+  def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
+    assert(resourceOffer(rootPool) === expectedTaskSetId)
+  }
+
+  test("FIFO Scheduler Test") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    val taskSet = FakeTask.createTaskSet(1)
+
+    val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
+    val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
+    schedulableBuilder.buildPools()
+
+    val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
+    val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
+    val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
+    schedulableBuilder.addTaskSetManager(taskSetManager0, null)
+    schedulableBuilder.addTaskSetManager(taskSetManager1, null)
+    schedulableBuilder.addTaskSetManager(taskSetManager2, null)
+
+    checkTaskSetId(rootPool, 0)
+    resourceOffer(rootPool)
+    checkTaskSetId(rootPool, 1)
+    resourceOffer(rootPool)
+    taskSetManager1.abort()
+    checkTaskSetId(rootPool, 2)
+  }
+
+  test("Fair Scheduler Test") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    val taskSet = FakeTask.createTaskSet(1)
+
+    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+    System.setProperty("spark.scheduler.allocation.file", xmlPath)
+    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+    val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
+    schedulableBuilder.buildPools()
+
+    assert(rootPool.getSchedulableByName("default") != null)
+    assert(rootPool.getSchedulableByName("1") != null)
+    assert(rootPool.getSchedulableByName("2") != null)
+    assert(rootPool.getSchedulableByName("3") != null)
+    assert(rootPool.getSchedulableByName("1").minShare === 2)
+    assert(rootPool.getSchedulableByName("1").weight === 1)
+    assert(rootPool.getSchedulableByName("2").minShare === 3)
+    assert(rootPool.getSchedulableByName("2").weight === 1)
+    assert(rootPool.getSchedulableByName("3").minShare === 0)
+    assert(rootPool.getSchedulableByName("3").weight === 1)
+
+    val properties1 = new Properties()
+    properties1.setProperty("spark.scheduler.pool","1")
+    val properties2 = new Properties()
+    properties2.setProperty("spark.scheduler.pool","2")
+
+    val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
+    val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
+    val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
+    schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
+    schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
+    schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
+
+    val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
+    val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
+    schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
+    schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
+
+    checkTaskSetId(rootPool, 0)
+    checkTaskSetId(rootPool, 3)
+    checkTaskSetId(rootPool, 3)
+    checkTaskSetId(rootPool, 1)
+    checkTaskSetId(rootPool, 4)
+    checkTaskSetId(rootPool, 2)
+    checkTaskSetId(rootPool, 2)
+    checkTaskSetId(rootPool, 4)
+
+    taskSetManager12.taskFinished()
+    assert(rootPool.getSchedulableByName("1").runningTasks === 3)
+    taskSetManager24.abort()
+    assert(rootPool.getSchedulableByName("2").runningTasks === 2)
+  }
+
+  test("Nested Pool Test") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    val taskSet = FakeTask.createTaskSet(1)
+
+    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+    val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
+    val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
+    rootPool.addSchedulable(pool0)
+    rootPool.addSchedulable(pool1)
+
+    val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
+    val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
+    pool0.addSchedulable(pool00)
+    pool0.addSchedulable(pool01)
+
+    val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
+    val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
+    pool1.addSchedulable(pool10)
+    pool1.addSchedulable(pool11)
+
+    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
+    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
+    pool00.addSchedulable(taskSetManager000)
+    pool00.addSchedulable(taskSetManager001)
+
+    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
+    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
+    pool01.addSchedulable(taskSetManager010)
+    pool01.addSchedulable(taskSetManager011)
+
+    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
+    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
+    pool10.addSchedulable(taskSetManager100)
+    pool10.addSchedulable(taskSetManager101)
+
+    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
+    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
+    pool11.addSchedulable(taskSetManager110)
+    pool11.addSchedulable(taskSetManager111)
+
+    checkTaskSetId(rootPool, 0)
+    checkTaskSetId(rootPool, 4)
+    checkTaskSetId(rootPool, 6)
+    checkTaskSetId(rootPool, 2)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 7532da88c606..a222be0d8045 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -17,260 +17,141 @@
 
 package org.apache.spark.scheduler
 
-import java.util.Properties
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
 
 import org.apache.spark._
 
 class FakeSchedulerBackend extends SchedulerBackend {
+
   def start() {}
   def stop() {}
-  def reviveOffers() {}
+  def reviveOffers() {
+
+  }
   def defaultParallelism() = 1
 }
 
-class FakeTaskSetManager(
-    initPriority: Int,
-    initStageId: Int,
-    initNumTasks: Int,
-    taskScheduler: TaskSchedulerImpl,
-    taskSet: TaskSet)
-  extends TaskSetManager(taskScheduler, taskSet, 0) {
-
-  parent = null
-  weight = 1
-  minShare = 2
-  priority = initPriority
-  stageId = initStageId
-  name = "TaskSet_"+stageId
-  override val numTasks = initNumTasks
-  tasksSuccessful = 0
+class TaskSchedulerImplSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
+  with LocalSparkContext with Logging {
 
-  var numRunningTasks = 0
-  override def runningTasks = numRunningTasks
+  private var taskScheduler: TaskSchedulerImpl = _
 
-  def increaseRunningTasks(taskNum: Int) {
-    numRunningTasks += taskNum
-    if (parent != null) {
-      parent.increaseRunningTasks(taskNum)
-    }
+  before {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
   }
 
-  def decreaseRunningTasks(taskNum: Int) {
-    numRunningTasks -= taskNum
-    if (parent != null) {
-      parent.decreaseRunningTasks(taskNum)
+  test("ClusterSchedulerBackend is started and stopped properly") {
+    taskScheduler = new TaskSchedulerImpl(sc)
+    val mockSchedulerBackend = mock[SchedulerBackend]
+    taskScheduler.initialize(mockSchedulerBackend)
+    expecting {
+      mockSchedulerBackend.start()
+      mockSchedulerBackend.stop()
     }
-  }
-
-  override def addSchedulable(schedulable: Schedulable) {
-  }
-
-  override def removeSchedulable(schedulable: Schedulable) {
-  }
-
-  override def getSchedulableByName(name: String): Schedulable = {
-    null
-  }
-
-  override def executorLost(executorId: String, host: String): Unit = {
-  }
-
-  override def resourceOffer(
-      execId: String,
-      host: String,
-      maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription] =
-  {
-    if (tasksSuccessful + numRunningTasks < numTasks) {
-      increaseRunningTasks(1)
-      Some(new TaskDescription(0, execId, "task 0:0", 0, null))
-    } else {
-      None
+    whenExecuting(mockSchedulerBackend) {
+      taskScheduler.start()
+      taskScheduler.stop()
     }
   }
 
-  override def checkSpeculatableTasks(): Boolean = {
-    true
+  test("speculative execution task is started and stopped properly") {
+    sc.conf.set("spark.speculation", "true")
+    taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.backend = new FakeSchedulerBackend
+    taskScheduler.start()
+    assert(taskScheduler.speculativeExecTask != null)
+    taskScheduler.stop()
+    assert(taskScheduler.speculativeExecTask.isCancelled === true)
   }
 
-  def taskFinished() {
-    decreaseRunningTasks(1)
-    tasksSuccessful +=1
-    if (tasksSuccessful == numTasks) {
-      parent.removeSchedulable(this)
+  test("tasks are successfully submitted and reviveOffer is called") {
+    taskScheduler = new TaskSchedulerImpl(sc)
+    val mockSchedulerBackend = mock[SchedulerBackend]
+    taskScheduler.initialize(mockSchedulerBackend)
+    val taskSet1 = FakeTask.createTaskSet(1, 0, 0)
+    expecting {
+      mockSchedulerBackend.reviveOffers()
    }
-  }
-
-  def abort() {
-    decreaseRunningTasks(numRunningTasks)
-    parent.removeSchedulable(this)
-  }
-}
-
-class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging {
-
-  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
-    taskSet: TaskSet): FakeTaskSetManager = {
-    new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
-  }
-
-  def resourceOffer(rootPool: Pool): Int = {
-    val taskSetQueue = rootPool.getSortedTaskSetQueue
-    /* Just for Test*/
-    for (manager <- taskSetQueue) {
-      logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
-        manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+    whenExecuting(mockSchedulerBackend) {
+      taskScheduler.submitTasks(taskSet1)
     }
-    for (taskSet <- taskSetQueue) {
-      taskSet.resourceOffer("execId_1", "hostname_1", TaskLocality.ANY) match {
-        case Some(task) =>
-          return taskSet.stageId
-        case None => {}
-      }
+    assert(taskScheduler.activeTaskSets(taskSet1.id) != null)
+  }
+
+  test("running tasks are cancelled and the involved stage is aborted when calling cancelTasks") {
+    taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.setDAGScheduler(sc.dagScheduler)
+    val taskSet1 = FakeTask.createTaskSet(1, 0, 0)
+    val tm = new TaskSetManager(taskScheduler, taskSet1, 1)
+    val mockSchedulerBackend = mock[SchedulerBackend]
+    taskScheduler.initialize(mockSchedulerBackend)
+    tm.runningTasksSet += 0
+    taskScheduler.activeTaskSets += taskSet1.id -> tm
+    taskScheduler.schedulableBuilder.addTaskSetManager(tm, tm.taskSet.properties)
+    val dummyTaskId = taskScheduler.newTaskId()
+    taskScheduler.taskIdToExecutorId += dummyTaskId -> "exec1"
+    expecting {
+      mockSchedulerBackend.killTask(dummyTaskId, "exec1", false)
+    }
+    whenExecuting(mockSchedulerBackend) {
+      taskScheduler.cancelTasks(0, false)
    }
-    -1
-  }
-
-  def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
-    assert(resourceOffer(rootPool) === expectedTaskSetId)
-  }
-
-  test("FIFO Scheduler Test") {
-    sc = new SparkContext("local", "TaskSchedulerImplSuite")
-    val taskScheduler = new TaskSchedulerImpl(sc)
-    val taskSet = FakeTask.createTaskSet(1)
-
-    val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
-    val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
-    schedulableBuilder.buildPools()
-
-    val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
-    val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
-    val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
-    schedulableBuilder.addTaskSetManager(taskSetManager0, null)
-    schedulableBuilder.addTaskSetManager(taskSetManager1, null)
-    schedulableBuilder.addTaskSetManager(taskSetManager2, null)
-
-    checkTaskSetId(rootPool, 0)
-    resourceOffer(rootPool)
-    checkTaskSetId(rootPool, 1)
-    resourceOffer(rootPool)
-    taskSetManager1.abort()
-    checkTaskSetId(rootPool, 2)
-  }
-
-  test("Fair Scheduler Test") {
-    sc = new SparkContext("local", "TaskSchedulerImplSuite")
-    val taskScheduler = new TaskSchedulerImpl(sc)
-    val taskSet = FakeTask.createTaskSet(1)
-
-    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
-    System.setProperty("spark.scheduler.allocation.file", xmlPath)
-    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
-    val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
-    schedulableBuilder.buildPools()
-
-    assert(rootPool.getSchedulableByName("default") != null)
-    assert(rootPool.getSchedulableByName("1") != null)
-    assert(rootPool.getSchedulableByName("2") != null)
-    assert(rootPool.getSchedulableByName("3") != null)
-    assert(rootPool.getSchedulableByName("1").minShare === 2)
-    assert(rootPool.getSchedulableByName("1").weight === 1)
-    assert(rootPool.getSchedulableByName("2").minShare === 3)
-    assert(rootPool.getSchedulableByName("2").weight === 1)
-    assert(rootPool.getSchedulableByName("3").minShare === 0)
-    assert(rootPool.getSchedulableByName("3").weight === 1)
-
-    val properties1 = new Properties()
-    properties1.setProperty("spark.scheduler.pool","1")
-    val properties2 = new Properties()
-    properties2.setProperty("spark.scheduler.pool","2")
-
-    val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
-    val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
-    val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
-    schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
-    schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
-    schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
-
-    val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
-    val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
-    schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
-    schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
-
-    checkTaskSetId(rootPool, 0)
-    checkTaskSetId(rootPool, 3)
-    checkTaskSetId(rootPool, 3)
-    checkTaskSetId(rootPool, 1)
-    checkTaskSetId(rootPool, 4)
-    checkTaskSetId(rootPool, 2)
-    checkTaskSetId(rootPool, 2)
-    checkTaskSetId(rootPool, 4)
-
-    taskSetManager12.taskFinished()
-    assert(rootPool.getSchedulableByName("1").runningTasks === 3)
-    taskSetManager24.abort()
-    assert(rootPool.getSchedulableByName("2").runningTasks === 2)
   }
 
-  test("Nested Pool Test") {
-    sc = new SparkContext("local", "TaskSchedulerImplSuite")
-    val taskScheduler = new TaskSchedulerImpl(sc)
-    val taskSet = FakeTask.createTaskSet(1)
-
-    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
-    val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
-    val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
-    rootPool.addSchedulable(pool0)
-    rootPool.addSchedulable(pool1)
-
-    val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
-    val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
-    pool0.addSchedulable(pool00)
-    pool0.addSchedulable(pool01)
-
-    val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
-    val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
-    pool1.addSchedulable(pool10)
-    pool1.addSchedulable(pool11)
-
-    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
-    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
-    pool00.addSchedulable(taskSetManager000)
-    pool00.addSchedulable(taskSetManager001)
-
-    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
-    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
-    pool01.addSchedulable(taskSetManager010)
-    pool01.addSchedulable(taskSetManager011)
-
-    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
-    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
-    pool10.addSchedulable(taskSetManager100)
-    pool10.addSchedulable(taskSetManager101)
-
-    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
-    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
-    pool11.addSchedulable(taskSetManager110)
-    pool11.addSchedulable(taskSetManager111)
-
-    checkTaskSetId(rootPool, 0)
-    checkTaskSetId(rootPool, 4)
-    checkTaskSetId(rootPool, 6)
-    checkTaskSetId(rootPool, 2)
+  test("the data structures are updated properly when the task state is updated") {
+    val taskSet1 = FakeTask.createTaskSet(1, 0, 0)
+    val tm = new TaskSetManager(taskScheduler, taskSet1, 1)
+    taskScheduler.schedulableBuilder.addTaskSetManager(tm, tm.taskSet.properties)
+    tm.taskInfos += 0.toLong -> new TaskInfo(0, 0, 0, 0, "exec1", "host1", TaskLocality.ANY, false)
+    taskScheduler.activeExecutorIds += "exec1"
+    taskScheduler.executorIdToHost += "exec1" -> "host1"
+    taskScheduler.executorsByHost += "host1" -> HashSet[String]("exec1")
+    taskScheduler.taskIdToExecutorId += 0.toLong -> "exec1"
+    taskScheduler.taskIdToTaskSetId += 0.toLong -> taskSet1.id
+    taskScheduler.activeTaskSets += taskSet1.id -> tm
+    assert(taskScheduler.activeExecutorIds.contains("exec1") === true)
+    assert(taskScheduler.executorIdToHost.contains("exec1") === true)
+    assert(taskScheduler.executorsByHost.contains("host1") === true)
+    assert(taskScheduler.taskIdToTaskSetId.contains(0) === true)
+    taskScheduler.statusUpdate(0, TaskState.LOST, null)
+    assert(taskScheduler.activeExecutorIds.contains("exec1") === false)
+    assert(taskScheduler.executorIdToHost.contains("exec1") === false)
+    assert(taskScheduler.executorsByHost.contains("host1") === false)
+    assert(taskScheduler.taskIdToTaskSetId.contains(0) === false)
+  }
+
+  test("task result is correctly enqueued according to the updated state") {
+    taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.setDAGScheduler(sc.dagScheduler)
+    val taskSet1 = FakeTask.createTaskSet(2, 0, 0)
+    val tm = new TaskSetManager(taskScheduler, taskSet1, 1)
+    tm.taskInfos += 0.toLong -> new TaskInfo(0, 0, 0, 0, "exec1", "host1", TaskLocality.ANY, false)
+    tm.taskInfos += 1.toLong -> new TaskInfo(0, 1, 0, 0, "exec1", "host1", TaskLocality.ANY, false)
"exec1", "host1", TaskLocality.ANY, false) + taskScheduler.activeExecutorIds += "exec1" + taskScheduler.executorIdToHost += "exec1" -> "host1" + taskScheduler.executorsByHost += "host1" -> HashSet[String]("exec1") + taskScheduler.taskIdToExecutorId += 0.toLong -> "exec1" + taskScheduler.taskIdToExecutorId += 1.toLong -> "exec1" + taskScheduler.taskIdToTaskSetId += 0.toLong -> taskSet1.id + taskScheduler.taskIdToTaskSetId += 1.toLong -> taskSet1.id + taskScheduler.activeTaskSets += taskSet1.id -> tm + taskScheduler.taskResultGetter = mock[TaskResultGetter] + expecting { + taskScheduler.taskResultGetter.enqueueFailedTask(tm, 0, TaskState.FAILED, null) + taskScheduler.taskResultGetter.enqueueSuccessfulTask(tm, 1, null) + } + whenExecuting(taskScheduler.taskResultGetter) { + taskScheduler.statusUpdate(0, TaskState.FAILED, null) + taskScheduler.statusUpdate(1, TaskState.FINISHED, null) + } } test("Scheduler does not always schedule tasks on the same workers") { - sc = new SparkContext("local", "TaskSchedulerImplSuite") - val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) - // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. - val dagScheduler = new DAGScheduler(sc, taskScheduler) { - override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} - override def executorAdded(execId: String, host: String) {} - } val numFreeCores = 1 val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), @@ -298,11 +179,12 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin val taskCpus = 2 sc.conf.set("spark.task.cpus", taskCpus.toString) - val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. val dagScheduler = new DAGScheduler(sc, taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} } taskScheduler.setDAGScheduler(dagScheduler)