diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index af17b5d5d2571..8e1b3fa428d80 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,7 +34,6 @@ import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD @@ -124,6 +123,10 @@ class DAGScheduler( /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + /** Broadcast the serialized task only when it is larger than this size (in bytes). */ + private val broadcastTaskMinSize = + sc.getConf.getInt("spark.scheduler.broadcastTaskMinSizeKB", 8) * 1024 + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) @@ -820,22 +823,29 @@ class DAGScheduler( listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. - // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // We serialize the task first, then broadcast it if the number of serialized bytes is + // larger than broadcastTaskMinSize, used to dispatch tasks to executors. Note that we broadcast // the serialized copy of the RDD and for each task we will deserialize it, which means each // task gets a different copy of the RDD. This provides stronger isolation between tasks that // might modify state of objects referenced in their closures. 
This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. - var taskBinary: Broadcast[Array[Byte]] = null + var taskBinary: Array[Byte] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = + taskBinary = if (stage.isShuffleMap) { closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() } else { closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() } - taskBinary = sc.broadcast(taskBinaryBytes) + if (taskBinary.size > broadcastTaskMinSize) { + logInfo(s"Create broadcast for taskBinary: ${taskBinary.size} > $broadcastTaskMinSize") + val broadcasted = sc.broadcast(taskBinary) + // use stage to track the life cycle of broadcast + stage.broadcastedTaskBinary = broadcasted + taskBinary = closureSerializer.serialize(broadcasted: AnyRef).array() + } } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 4a9ff918afe25..16957ef21fed3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * @param taskBinary serialized RDD and the function to apply on each * partition of the given RDD. Once deserialized, the type should be * (RDD[T], (TaskContext, Iterator[T]) => U). 
* @param partition partition of the RDD this task is associated with @@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, - taskBinary: Broadcast[Array[Byte]], + taskBinary: Array[Byte], partition: Partition, @transient locs: Seq[TaskLocation], val outputId: Int) @@ -54,8 +54,14 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. val ser = SparkEnv.get.closureSerializer.newInstance() - val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( - ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + val obj = ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary), + Thread.currentThread.getContextClassLoader) + val (rdd, func) = obj match { + case (rdd: RDD[T], func: ((TaskContext, Iterator[T]) => U)) => (rdd, func) + case b: Broadcast[Array[Byte]] => + ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(b.value), Thread.currentThread.getContextClassLoader) + } metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 79709089c0da4..f9c04e6abd3f5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -33,14 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized, + * @param taskBinary serialized RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). 
* @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - taskBinary: Broadcast[Array[Byte]], + taskBinary: Array[Byte], partition: Partition, @transient private var locs: Seq[TaskLocation]) extends Task[MapStatus](stageId, partition.index) with Logging { @@ -57,8 +57,14 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. val ser = SparkEnv.get.closureSerializer.newInstance() - val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( - ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + val obj = ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary), + Thread.currentThread.getContextClassLoader) + val (rdd, dep) = obj match { + case (rdd: RDD[_], dep: ShuffleDependency[_, _, _]) => (rdd, dep) + case b: Broadcast[Array[Byte]] => + ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + ByteBuffer.wrap(b.value), Thread.currentThread.getContextClassLoader) + } metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index cc13f57a49b89..cfa80c8d04661 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite @@ -69,6 +70,12 @@ private[spark] class Stage( var resultOfJob: Option[ActiveJob] = None var pendingTasks = new HashSet[Task[_]] + /** + * This is used to track the life cycle of 
broadcast, + * so that it can be released by GC once the stage is released + */ + var broadcastedTaskBinary: Broadcast[Array[Byte]] = _ + private var nextAttemptId = 0 val name = callSite.shortForm diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 561a5e9cd90c4..8c52906bc187e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -43,7 +43,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val task = new ResultTask[String, String]( - 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + 0, closureSerializer.serialize((rdd, func)).array, rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) }