diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 30bceb47b9e7d..aa1120f7723e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -93,6 +93,7 @@ private[spark] class TaskSchedulerImpl(
   val mapOutputTracker = SparkEnv.get.mapOutputTracker
 
   var schedulableBuilder: SchedulableBuilder = null
+  var rootPool: Pool = null
 
   // default scheduler is FIFO
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -196,9 +197,8 @@ private[spark] class TaskSchedulerImpl(
    * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
    * that tasks are balanced across the cluster.
    */
-  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
+  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescWithoutSerializedTask]] = synchronized {
     SparkEnv.set(sc.env)
-
     // Mark each slave as alive and remember its hostname
     for (o <- offers) {
       executorIdToHost(o.executorId) = o.host
@@ -211,7 +211,7 @@ private[spark] class TaskSchedulerImpl(
     // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
     val shuffledOffers = Random.shuffle(offers)
     // Build a list of tasks to assign to each worker.
-    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+    val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescWithoutSerializedTask](o.cores))
     val availableCpus = shuffledOffers.map(o => o.cores).toArray
     val sortedTaskSets = rootPool.getSortedTaskSetQueue()
     for (taskSet <- sortedTaskSets) {
@@ -228,9 +228,9 @@ private[spark] class TaskSchedulerImpl(
         for (i <- 0 until shuffledOffers.size) {
           val execId = shuffledOffers(i).executorId
           val host = shuffledOffers(i).host
-          for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
-            tasks(i) += task
-            val tid = task.taskId
+          for (taskNoSer <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
+            tasks(i) += taskNoSer
+            val tid = taskNoSer.taskId
             taskIdToTaskSetId(tid) = taskSet.taskSet.id
             taskIdToExecutorId(tid) = execId
             activeExecutorIds += execId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a73343c1c0826..a7cc574a7ebc1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -32,6 +32,10 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.util.{Clock, SystemClock}
 
+private[spark] class TaskDescWithoutSerializedTask(
+    val taskId: Long, val executorId: String, val taskName: String, val index: Int,
+    val taskObject: Task[_])
+
 /**
  * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
  * each task, retries tasks if they fail (up to a limited number of times), and
@@ -386,7 +390,7 @@ private[spark] class TaskSetManager(
       host: String,
       availableCpus: Int,
       maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription] =
+    : Option[TaskDescWithoutSerializedTask] =
   {
     if (!isZombie && availableCpus >= CPUS_PER_TASK) {
       val curTime = clock.getTime()
@@ -412,19 +416,12 @@ private[spark] class TaskSetManager(
           // Update our locality level for delay scheduling
           currentLocalityIndex = getLocalityIndex(taskLocality)
           lastLaunchTime = curTime
-          // Serialize and return the task
-          val startTime = clock.getTime()
-          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
-          // we assume the task can be serialized without exceptions.
-          val serializedTask = Task.serializeWithDependencies(
-            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
-          val timeTaken = clock.getTime() - startTime
+
           addRunningTask(taskId)
-          logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
-            taskSet.id, index, serializedTask.limit, timeTaken))
+          val taskName = "task %s:%d".format(taskSet.id, index)
           sched.dagScheduler.taskStarted(task, info)
-          return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+          return Some(new TaskDescWithoutSerializedTask(taskId, execId, taskName, index, task))
         }
         case _ =>
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index fad03731572e7..877346e679d2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -27,10 +27,12 @@ import akka.actor._
 import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 
-import org.apache.spark.{Logging, SparkException, TaskState}
+import org.apache.spark.{Logging, SparkException, TaskState, SparkEnv}
 import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler.{TaskDescWithoutSerializedTask, Task}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.serializer.SerializerInstance
 
 /**
  * A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -48,6 +50,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
   var totalCoreCount = new AtomicInteger(0)
   val conf = scheduler.sc.conf
   private val timeout = AkkaUtils.askTimeout(conf)
+  val serializeWorkerPool = Utils.newDaemonFixedThreadPool(
+    conf.getInt("spark.scheduler.task.serialize.threads", 4), "task-serialization")
+
+  val env = SparkEnv.get
+  protected val serializer = new ThreadLocal[SerializerInstance] {
+    override def initialValue(): SerializerInstance = {
+      env.closureSerializer.newInstance()
+    }
+  }
+
+  class TaskCGSerializedRunner(execActor: ActorRef,
+      taskNoSer: TaskDescWithoutSerializedTask,
+      scheduler: TaskSchedulerImpl)
+    extends Runnable {
+    override def run() {
+      // Serialize the task on a pool thread, then ship it to the executor
+      val task = Utils.serializeTask(taskNoSer, scheduler.sc, serializer.get())
+      execActor ! LaunchTask(task)
+    }
+  }
 
   class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
     private val executorActor = new HashMap[String, ActorRef]
@@ -57,6 +79,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
     private val totalCores = new HashMap[String, Int]
     private val addressToExecutorId = new HashMap[Address, String]
 
+
     override def preStart() {
       // Listen for remote client disconnection events, since they don't go through Akka's watch()
      context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -138,10 +161,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
     }
 
     // Launch tasks returned by a set of resource offers
-    def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
-      for (task <- tasks.flatten) {
-        freeCores(task.executorId) -= 1
-        executorActor(task.executorId) ! LaunchTask(task)
+    def launchTasks(tasks: Seq[Seq[TaskDescWithoutSerializedTask]]) {
+      for (taskNoSer <- tasks.flatten) {
+        freeCores(taskNoSer.executorId) -= 1
+        serializeWorkerPool.execute(
+          new TaskCGSerializedRunner(executorActor(taskNoSer.executorId), taskNoSer, scheduler)
+        )
       }
     }
 
@@ -191,6 +216,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
   }
 
   override def stop() {
+    if (serializeWorkerPool != null) {
+      serializeWorkerPool.shutdownNow()
+    }
     stopExecutors()
     try {
       if (driverActor != null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 4092dd04b112b..f824db4514c76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -29,9 +29,12 @@ import org.apache.mesos.{Scheduler => MScheduler}
 import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 
-import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
-import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.{SparkEnv, Logging, SparkContext, SparkException, TaskState}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost}
+import org.apache.spark.scheduler.{TaskDescWithoutSerializedTask, TaskDescription, TaskSchedulerImpl, WorkerOffer}
 import org.apache.spark.util.Utils
+import org.apache.spark.serializer.SerializerInstance
+
 
 /**
  * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
@@ -62,6 +65,30 @@ private[spark] class MesosSchedulerBackend(
 
   var classLoader: ClassLoader = null
 
+  val serializeWorkerPool = Utils.newDaemonFixedThreadPool(
+    scheduler.sc.conf.getInt("spark.scheduler.task.serialize.threads", 4), "task-serialization")
+
+  val env = SparkEnv.get
+  protected val serializer = new ThreadLocal[SerializerInstance] {
+    override def initialValue(): SerializerInstance = {
+      env.closureSerializer.newInstance()
+    }
+  }
+
+  class TaskMesosSerializedRunner(taskNoSer: TaskDescWithoutSerializedTask,
+      taskList: JList[MesosTaskInfo],
+      slaveId: String,
+      scheduler: TaskSchedulerImpl)
+    extends Runnable {
+    override def run() {
+      // Serialize the task on a pool thread, then append the resulting Mesos task
+      val task = Utils.serializeTask(taskNoSer, scheduler.sc, serializer.get())
+      taskList.synchronized {
+        taskList.add(createMesosTask(task, slaveId))
+      }
+    }
+  }
+
   override def start() {
     synchronized {
       classLoader = Thread.currentThread.getContextClassLoader
@@ -213,9 +240,11 @@ private[spark] class MesosSchedulerBackend(
         val slaveId = offers(offerNum).getSlaveId.getValue
         slaveIdsWithExecutors += slaveId
         mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size)
-        for (taskDesc <- taskList) {
-          taskIdToSlaveId(taskDesc.taskId) = slaveId
-          mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
+        for (taskNoSer <- taskList) {
+          taskIdToSlaveId(taskNoSer.taskId) = slaveId
+          val mesosTaskList = mesosTasks(offerNum)
+          serializeWorkerPool.execute(
+            new TaskMesosSerializedRunner(taskNoSer, mesosTaskList, slaveId, scheduler))
         }
       }
     }
@@ -297,6 +326,9 @@ private[spark] class MesosSchedulerBackend(
   }
 
   override def stop() {
+    if (serializeWorkerPool != null) {
+      serializeWorkerPool.shutdownNow()
+    }
     if (driver != null) {
       driver.stop()
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 16e2f5cf3076d..5fe5d2b23433b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{Logging, SparkEnv, TaskState}
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.{Executor, ExecutorBackend}
 import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.util.Utils
 
 private case class ReviveOffers()
 
@@ -46,6 +47,7 @@ private[spark] class LocalActor(
 
   private val localExecutorId = "localhost"
   private val localExecutorHostname = "localhost"
+  val ser = scheduler.sc.env.closureSerializer.newInstance()
 
   val executor = new Executor(
     localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true)
@@ -67,8 +69,9 @@ private[spark] class LocalActor(
 
   def reviveOffers() {
     val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
-    for (task <- scheduler.resourceOffers(offers).flatten) {
+    for (taskNoSer <- scheduler.resourceOffers(offers).flatten) {
       freeCores -= 1
+      val task = Utils.serializeTask(taskNoSer, scheduler.sc, ser)
       executor.launchTask(executorBackend, task.taskId, task.serializedTask)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 13d9dbdd9af2d..92bc5225f75ca 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -34,9 +34,11 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
 import org.json4s._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+import org.apache.spark.scheduler.{TaskDescWithoutSerializedTask, TaskDescription, Task}
 
 /**
  * Various utility methods used by Spark.
@@ -149,6 +151,21 @@ private[spark] object Utils extends Logging {
     buf
   }
 
+  def serializeTask(taskNoSer: TaskDescWithoutSerializedTask, sc: SparkContext,
+      serializer: SerializerInstance): TaskDescription = {
+    val startTime = System.currentTimeMillis()
+    // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+    // we assume the task can be serialized without exceptions.
+    val serializedTask = Task.serializeWithDependencies(
+      taskNoSer.taskObject, sc.addedFiles, sc.addedJars, serializer)
+    val timeTaken = System.currentTimeMillis() - startTime
+    logInfo("Serialized task %s as %d bytes in %d ms".format(
+      taskNoSer.taskName, serializedTask.limit, timeTaken))
+    val task = new TaskDescription(taskNoSer.taskId, taskNoSer.executorId,
+      taskNoSer.taskName, taskNoSer.index, serializedTask)
+    task
+  }
+
   private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
 
   // Register the path to be deleted via shutdown hook
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 9274e01632d58..22e4338304869 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -82,11 +82,11 @@ class FakeTaskSetManager(
       host: String,
       availableCpus: Int,
       maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription] =
+    : Option[TaskDescWithoutSerializedTask] =
   {
     if (tasksSuccessful + numRunningTasks < numTasks) {
       increaseRunningTasks(1)
-      Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+      Some(new TaskDescWithoutSerializedTask(0, execId, "task 0:0", 0, null))
     } else {
       None
     }
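Note on the pattern: both backends apply the same recipe, hand the still-unserialized task to a small daemon thread pool, serialize it there with a per-thread serializer instance, and only then dispatch the bytes. The sketch below shows that recipe in isolation, outside Spark; PendingTask, Serializer, the launch/dispatch signature, and the string-based serialize are illustrative stand-ins under assumption, not the classes added by this patch.

import java.util.concurrent.{Executors, ThreadFactory}
import java.util.concurrent.atomic.AtomicInteger

// Illustrative stand-ins for TaskDescWithoutSerializedTask and SerializerInstance.
case class PendingTask(taskId: Long, executorId: String, payload: Any)
trait Serializer { def serialize(obj: Any): Array[Byte] }

object SerializationOffloadSketch {
  // Fixed pool of daemon threads, in the spirit of
  // Utils.newDaemonFixedThreadPool(4, "task-serialization") used by the patch.
  private val threadId = new AtomicInteger(0)
  private val pool = Executors.newFixedThreadPool(4, new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, s"task-serialization-${threadId.getAndIncrement()}")
      t.setDaemon(true)
      t
    }
  })

  // One serializer per pool thread: serializer instances are generally not thread-safe,
  // which is why the patch wraps closureSerializer.newInstance() in a ThreadLocal.
  private val serializer = new ThreadLocal[Serializer] {
    override def initialValue(): Serializer = new Serializer {
      override def serialize(obj: Any): Array[Byte] = obj.toString.getBytes("UTF-8")
    }
  }

  /** Serialize `task` off the caller's thread, then hand the bytes to `dispatch`. */
  def launch(task: PendingTask)(dispatch: Array[Byte] => Unit): Unit = {
    pool.execute(new Runnable {
      override def run(): Unit = dispatch(serializer.get().serialize(task.payload))
    })
  }

  def shutdown(): Unit = { pool.shutdownNow(); () }
}

One consequence worth keeping in mind: once serialization runs on pool threads, tasks are dispatched in completion order rather than offer order, which is why the Mesos path in the diff collects the resulting MesosTaskInfo objects into a synchronized list instead of appending them inline.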