20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,7 +34,6 @@ import akka.pattern.ask
 import akka.util.Timeout

 import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
@@ -124,6 +123,10 @@ class DAGScheduler(
   /** If enabled, we may run certain actions like take() and first() locally. */
   private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)

+  /** Broadcast a serialized task only when it is larger than this size (in bytes). */
+  private val broadcastTaskMinSize =
+    sc.getConf.getInt("spark.scheduler.broadcastTaskMinSizeKB", 8) * 1024
+
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)

@@ -820,22 +823,29 @@ class DAGScheduler(
     listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

     // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
-    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
+    // Binary for the task, used to dispatch tasks to executors. We serialize the task first and
+    // broadcast it only if its serialized size exceeds broadcastTaskMinSize. Note that we broadcast
     // the serialized copy of the RDD and for each task we will deserialize it, which means each
     // task gets a different copy of the RDD. This provides stronger isolation between tasks that
     // might modify state of objects referenced in their closures. This is necessary in Hadoop
     // where the JobConf/Configuration object is not thread-safe.
-    var taskBinary: Broadcast[Array[Byte]] = null
+    var taskBinary: Array[Byte] = null
     try {
       // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
       // For ResultTask, serialize and broadcast (rdd, func).
-      val taskBinaryBytes: Array[Byte] =
+      taskBinary =
         if (stage.isShuffleMap) {
           closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
         } else {
           closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
         }
-      taskBinary = sc.broadcast(taskBinaryBytes)
+      if (taskBinary.size > broadcastTaskMinSize) {
+        logInfo(s"Create broadcast for taskBinary: ${taskBinary.size} > $broadcastTaskMinSize")
+        val broadcasted = sc.broadcast(taskBinary)
+        // Use the stage to track the life cycle of the broadcast.
+        stage.broadcastedTaskBinary = broadcasted
+        taskBinary = closureSerializer.serialize(broadcasted: AnyRef).array()
+      }
     } catch {
       // In the case of a failure during serialization, abort the stage.
       case e: NotSerializableException =>
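
The scheduler-side change above boils down to: serialize the task payload first, then wrap it in a broadcast variable only when the serialized size exceeds spark.scheduler.broadcastTaskMinSizeKB (8 KB by default). The following standalone Scala sketch illustrates that decision logic outside the DAGScheduler; BroadcastThresholdSketch and prepareTaskBinary are made-up names, the local-mode SparkContext and JavaSerializer stand in for the scheduler's closure serializer, and only the config key and its default come from the patch.

// A minimal, self-contained sketch (not the DAGScheduler itself) of the
// "broadcast only when large" policy introduced by this patch.
import java.nio.ByteBuffer

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.serializer.JavaSerializer

object BroadcastThresholdSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("broadcast-threshold-sketch")
      // Raise the threshold to 16 KB; the patch defaults to 8 KB.
      .set("spark.scheduler.broadcastTaskMinSizeKB", "16")
    val sc = new SparkContext(conf)
    val ser = new JavaSerializer(conf).newInstance()
    val minSize = conf.getInt("spark.scheduler.broadcastTaskMinSizeKB", 8) * 1024

    // Serialize first, then decide whether to broadcast, mirroring the diff:
    // small payloads ship as raw bytes, large ones as a serialized Broadcast handle.
    def prepareTaskBinary(payload: AnyRef): Array[Byte] = {
      val bytes = ser.serialize(payload).array()
      if (bytes.length > minSize) {
        val handle: Broadcast[Array[Byte]] = sc.broadcast(bytes)
        ser.serialize(handle: AnyRef).array()
      } else {
        bytes
      }
    }

    val small = prepareTaskBinary(Array.fill(100)(0.toByte))       // shipped inline
    val large = prepareTaskBinary(Array.fill(64 * 1024)(0.toByte)) // shipped via broadcast
    println(s"small task binary: ${small.length} bytes, large task binary: ${large.length} bytes")
    sc.stop()
  }
}

Tuning the threshold trades one thing for another: a higher value avoids broadcast round-trips for mid-sized closures, while a lower value keeps large task descriptions out of the scheduler's direct send path.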
14 changes: 10 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD
  * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
+ * @param taskBinary serialized RDD and the function to apply on each
  *                   partition of the given RDD. Once deserialized, the type should be
  *                   (RDD[T], (TaskContext, Iterator[T]) => U).
  * @param partition partition of the RDD this task is associated with
@@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    taskBinary: Array[Byte],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
@@ -54,8 +54,14 @@ private[spark] class ResultTask[T, U](
   override def runTask(context: TaskContext): U = {
     // Deserialize the RDD and the func using the broadcast variables.
     val ser = SparkEnv.get.closureSerializer.newInstance()
-    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
-      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+    val obj = ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary),
+      Thread.currentThread.getContextClassLoader)
+    val (rdd, func) = obj match {
+      case (rdd: RDD[T], func: ((TaskContext, Iterator[T]) => U)) => (rdd, func)
+      case b: Broadcast[Array[Byte]] =>
+        ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
+          ByteBuffer.wrap(b.value), Thread.currentThread.getContextClassLoader)
+    }

     metrics = Some(context.taskMetrics)
     func(context, rdd.iterator(partition, context))
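
On the executor side, ResultTask can no longer assume the bytes hold a (rdd, func) pair: it first deserializes to AnyRef and then branches on whether it got the pair directly or a Broadcast handle whose value must be deserialized again. The runnable sketch below isolates that two-branch decode with a toy payload; DemoPayload, decodePayload, and the local-mode setup are hypothetical stand-ins, while the deserialize-then-match pattern mirrors the diff.

import java.nio.ByteBuffer

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}

// Stand-in for the (rdd, func) pair a real task carries; case classes are Serializable.
case class DemoPayload(data: Seq[Int], label: String)

object TaskPayloadDecodeSketch {
  // Mirrors the two-branch decode in ResultTask.runTask: the bytes either hold the
  // payload itself (small tasks) or a Broadcast handle pointing at it (large tasks).
  def decodePayload(ser: SerializerInstance, taskBinary: Array[Byte]): DemoPayload = {
    val loader = Thread.currentThread.getContextClassLoader
    ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary), loader) match {
      case p: DemoPayload => p
      case b: Broadcast[_] =>
        val inner = b.value.asInstanceOf[Array[Byte]]
        ser.deserialize[DemoPayload](ByteBuffer.wrap(inner), loader)
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("decode-sketch")
    val sc = new SparkContext(conf)
    val ser = new JavaSerializer(conf).newInstance()
    val payload = DemoPayload(Seq(1, 2, 3), "demo")

    // Encoding 1: payload serialized directly into the task binary.
    val inline = ser.serialize(payload).array()
    // Encoding 2: payload bytes wrapped in a broadcast, and the handle serialized instead.
    val viaBroadcast = ser.serialize(sc.broadcast(ser.serialize(payload).array()): AnyRef).array()

    assert(decodePayload(ser, inline) == payload)
    assert(decodePayload(ser, viaBroadcast) == payload)
    sc.stop()
  }
}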
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -33,14 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param taskBinary broadcast version of of the RDD and the ShuffleDependency. Once deserialized,
+ * @param taskBinary serialized RDD and the ShuffleDependency. Once deserialized,
  *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
  * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    taskBinary: Array[Byte],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
   extends Task[MapStatus](stageId, partition.index) with Logging {
@@ -57,8 +57,14 @@ private[spark] class ShuffleMapTask(
   override def runTask(context: TaskContext): MapStatus = {
     // Deserialize the RDD using the broadcast variable.
     val ser = SparkEnv.get.closureSerializer.newInstance()
-    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
-      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+    val obj = ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary),
+      Thread.currentThread.getContextClassLoader)
+    val (rdd, dep) = obj match {
+      case (rdd: RDD[_], dep: ShuffleDependency[_, _, _]) => (rdd, dep)
+      case b: Broadcast[Array[Byte]] =>
+        ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
+          ByteBuffer.wrap(b.value), Thread.currentThread.getContextClassLoader)
+    }

     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
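
ShuffleMapTask repeats the same two-branch decode, only with an (RDD, ShuffleDependency) payload instead of (RDD, func). Since the match is now duplicated in both task types, a small shared helper along these lines could factor it out; resolveTaskBinary is a hypothetical refactor sketch, not something the patch adds.

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.serializer.SerializerInstance

// Hypothetical helper: deserialize a task binary that may contain either the payload
// itself or a Broadcast handle to its serialized form.
object TaskBinaryResolver {
  def resolveTaskBinary[T: ClassTag](ser: SerializerInstance, taskBinary: Array[Byte]): T = {
    val loader = Thread.currentThread.getContextClassLoader
    ser.deserialize[AnyRef](ByteBuffer.wrap(taskBinary), loader) match {
      case b: Broadcast[_] =>
        ser.deserialize[T](ByteBuffer.wrap(b.value.asInstanceOf[Array[Byte]]), loader)
      case payload => payload.asInstanceOf[T]
    }
  }
}

With such a helper, ResultTask would request a (RDD[T], (TaskContext, Iterator[T]) => U) and ShuffleMapTask a (RDD[_], ShuffleDependency[_, _, _]), keeping the broadcast-or-inline handling in one place.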
7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler
 import scala.collection.mutable.HashSet

 import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
@@ -69,6 +70,12 @@ private[spark] class Stage(
   var resultOfJob: Option[ActiveJob] = None
   var pendingTasks = new HashSet[Task[_]]

+  /**
+   * Tracks the broadcast variable holding the serialized task binary (if any),
+   * so that it can be released by GC once the stage itself is released.
+   */
+  var broadcastedTaskBinary: Broadcast[Array[Byte]] = _
+
   private var nextAttemptId = 0

   val name = callSite.shortForm
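
Storing the Broadcast handle on the Stage ties its reachability to the stage object: once the stage is dropped, the handle becomes unreachable and its storage can be reclaimed, which is what the new field's comment means by releasing it via GC. The patch relies on GC alone; the sketch below is a hypothetical illustration of how a stage-completion hook could also release the executor-side blocks eagerly. TrackedStage and releaseTaskBroadcast are made-up names; Broadcast.unpersist() is an existing Spark API.

import org.apache.spark.broadcast.Broadcast

// Hypothetical mirror of the new Stage field plus an eager cleanup hook.
class TrackedStage {
  // Same shape as Stage.broadcastedTaskBinary in the diff: stays null when the
  // task binary was small enough to ship inline.
  var broadcastedTaskBinary: Broadcast[Array[Byte]] = _
}

object StageCleanupSketch {
  def releaseTaskBroadcast(stage: TrackedStage): Unit = {
    // Drop cached broadcast blocks on the executors; the driver-side handle itself
    // is reclaimed once the stage (and thus this reference) goes away.
    Option(stage.broadcastedTaskBinary).foreach(_.unpersist())
    stage.broadcastedTaskBinary = null
  }
}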
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -43,7 +43,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val task = new ResultTask[String, String](
-      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+      0, closureSerializer.serialize((rdd, func)).array, rdd.partitions(0), Seq(), 0)
     intercept[RuntimeException] {
       task.run(0)
     }