Commit b93c37f
[SPARK-19276][CORE] Fetch Failure handling robust to user error handling
Fault-tolerance in Spark requires special handling of shuffle fetch failures: the Executor catches FetchFailedException and sends a special message back to the driver so that the lost map output can be regenerated. However, intervening user code can intercept that exception and wrap it with something else; this even happens in SparkSQL. So rather than inspecting the thrown exception directly, we store the fetch failure in the TaskContext, where user code can't touch it. This includes a test case which failed before the fix.
1 parent d06172b commit b93c37f
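
To make the failure mode concrete: the kind of user code that triggers this problem looks roughly like the sketch below (illustrative only, not part of the patch; it mirrors the FetchFailureHidingRDD added in the test suite). Consuming the shuffle iterator inside a try/catch wraps Spark's FetchFailedException in a plain RuntimeException, so the Executor's existing `case ffe: FetchFailedException` arm never matches and the driver never learns that map output was lost.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative user program (not part of the patch): the catch block below
// swallows *any* exception raised while pulling shuffle blocks -- including
// FetchFailedException -- and rethrows it wrapped in a RuntimeException.
object FetchFailureSwallowingJob {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("swallow-demo"))
    val sums = sc.parallelize(1 to 1000, 8)
      .map(i => (i % 10, i))
      .reduceByKey(_ + _)          // introduces a shuffle; reading it can throw FetchFailedException
      .mapPartitions { iter =>
        try {
          Iterator(iter.size)      // consuming the iterator is what actually fetches remote blocks
        } catch {
          case e: Exception =>
            throw new RuntimeException("record processing failed", e)  // hides the fetch failure
        }
      }
      .collect()
    println(sums.toSeq)
    sc.stop()
  }
}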

File tree

6 files changed: +167, -26 lines

core/src/main/scala/org/apache/spark/TaskContext.scala
core/src/main/scala/org/apache/spark/TaskContextImpl.scala
core/src/main/scala/org/apache/spark/executor/Executor.scala
core/src/main/scala/org/apache/spark/scheduler/Task.scala
core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 7 additions & 0 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
 
 
@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
    */
   private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
 
+  /**
+   * Record that this task has failed due to a fetch failure from a remote host. This allows
+   * fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
+   */
+  private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
+
 }

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
@@ -56,6 +57,8 @@ private[spark] class TaskContextImpl(
   // Whether the task has failed.
   @volatile private var failed: Boolean = false
 
+  var fetchFailed: Option[FetchFailedException] = None
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -126,4 +129,8 @@ private[spark] class TaskContextImpl(
     taskMetrics.registerAccumulator(a)
   }
 
+  private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+    this.fetchFailed = Some(fetchFailed)
+  }
+
 }

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 17 additions & 0 deletions
@@ -148,6 +148,8 @@ private[spark] class Executor(
 
   startDriverHeartbeater()
 
+  private[executor] def numRunningTasks: Int = runningTasks.size()
+
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
     val tr = new TaskRunner(context, taskDescription)
     runningTasks.put(taskDescription.taskId, tr)
@@ -340,6 +342,14 @@ private[spark] class Executor(
             }
           }
         }
+        task.context.fetchFailed.foreach { fetchFailure =>
+          // Uh-oh. It appears the user code has caught the fetch failure without throwing any
+          // other exceptions. It's *possible* this is what the user meant to do (though highly
+          // unlikely), so we will log an error and keep going.
+          logError(s"TID ${taskId} completed successfully though internally it encountered " +
+            s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
+            s"swallowing Spark's internal exceptions", fetchFailure)
+        }
         val taskFinish = System.currentTimeMillis()
         val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
           threadMXBean.getCurrentThreadCpuTime
@@ -405,6 +415,13 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
+        case t: Throwable if task.context.fetchFailed.isDefined =>
+          // There was a fetch failure in the task, but some user code wrapped that exception
+          // and threw something else. Regardless, we treat it as a fetch failure.
+          val reason = task.context.fetchFailed.get.toTaskFailedReason
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+
         case _: TaskKilledException =>
          logInfo(s"Executor killed $taskName (TID $taskId)")
          setTaskFinishedAndClearInterruptStatus()
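
The ordering of the new catch arm matters: the guard on `task.context.fetchFailed` has to come before the generic failure handling, otherwise a wrapped fetch failure would be reported as an ordinary exception and the map stage would never be resubmitted. A minimal, standalone sketch of that control flow follows; the classes and helpers here are illustrative stand-ins, not Spark's own types.

object CatchOrderSketch {
  class FetchFailedException(msg: String) extends Exception(msg)

  sealed trait TaskFailedReason
  case class FetchFailed(msg: String) extends TaskFailedReason
  case class ExceptionFailure(msg: String) extends TaskFailedReason

  // Mirrors the shape of TaskRunner.run()'s catch block: the guarded case runs
  // before the catch-all, so a wrapped fetch failure is still reported as FetchFailed.
  def runAndReport(
      body: () => Unit,
      fetchFailed: () => Option[FetchFailedException]): TaskFailedReason = {
    try {
      body()
      ExceptionFailure("task succeeded; no failure reason")
    } catch {
      case ffe: FetchFailedException =>
        FetchFailed(ffe.getMessage)                      // thrown directly
      case _: Throwable if fetchFailed().isDefined =>
        FetchFailed(fetchFailed().get.getMessage)        // wrapped by user code
      case t: Throwable =>
        ExceptionFailure(t.getMessage)                   // everything else
    }
  }

  def main(args: Array[String]): Unit = {
    var recorded: Option[FetchFailedException] = None    // stand-in for TaskContextImpl.fetchFailed
    val wrappedFailure = () => {
      val ffe = new FetchFailedException("block fetch failed")
      recorded = Some(ffe)                               // what setFetchFailed does
      throw new RuntimeException("user wrapper", ffe)    // what intervening user code does
    }
    println(runAndReport(wrappedFailure, () => recorded))
  }
}

Running the sketch prints FetchFailed(block fetch failed) even though the task body actually threw a RuntimeException.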

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 4 additions & 6 deletions
@@ -17,19 +17,15 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 import java.util.Properties
 
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 /**
@@ -137,6 +133,8 @@ private[spark] abstract class Task[T](
         memoryManager.synchronized { memoryManager.notifyAll() }
       }
     } finally {
+      // though we unset the ThreadLocal here, the context itself is still queried directly
+      // in the TaskRunner to check for FetchFailedExceptions
      TaskContext.unset()
     }
   }
@@ -156,7 +154,7 @@ private[spark] abstract class Task[T](
   var epoch: Long = -1
 
   // Task context, to be initialized in run().
-  @transient protected var context: TaskContextImpl = _
+  @transient var context: TaskContextImpl = _
 
   // The actual Thread on which the task is running, if any. Initialized in run().
   @volatile @transient private var taskThread: Thread = _

core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala

Lines changed: 7 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle
 
-import org.apache.spark.{FetchFailed, TaskFailedReason}
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
 
@@ -45,6 +45,12 @@ private[spark] class FetchFailedException(
     this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
   }
 
+  // SPARK-19276. We set the fetch failure in the task context, so that even if there is user code
+  // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+  // a fetch failure, and send the correct error message back to the driver. The TaskContext won't
+  // be defined if this is run on the driver (just in test cases) -- we can safely ignore then.
+  Option(TaskContext.get()).map(_.setFetchFailed(this))
+
   def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
     Utils.exceptionString(this))
 }
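
The constructor-side call above is the crux of the fix: the exception records itself in the task's context the moment it is constructed, before any user catch block can hide it. Below is a standalone sketch of the same pattern, with a plain ThreadLocal standing in for Spark's TaskContext; all names in it are illustrative.

object SelfRegisteringFailureSketch {
  // Stand-in for TaskContext.get()/setFetchFailed: a per-thread slot that the
  // framework can inspect after the task body returns or throws.
  private val recordedFailure = new ThreadLocal[Option[CriticalFailure]] {
    override def initialValue(): Option[CriticalFailure] = None
  }

  class CriticalFailure(msg: String) extends Exception(msg) {
    // Side effect in the constructor, like FetchFailedException registering
    // itself via Option(TaskContext.get()).map(_.setFetchFailed(this)).
    recordedFailure.set(Some(this))
  }

  def main(args: Array[String]): Unit = {
    try {
      try throw new CriticalFailure("remote fetch failed")
      catch { case e: Exception => throw new RuntimeException("wrapped by user code", e) }
    } catch {
      case _: Throwable =>
        // The framework still sees the original failure, regardless of the wrapping.
        assert(recordedFailure.get().isDefined)
        println(s"detected: ${recordedFailure.get().get.getMessage}")
    } finally {
      recordedFailure.remove()
    }
  }
}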

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 125 additions & 19 deletions
@@ -23,29 +23,34 @@ import java.util.concurrent.CountDownLatch
 
 import scala.collection.mutable.Map
 
-import org.mockito.Matchers._
-import org.mockito.Mockito.{mock, when}
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{inOrder, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.memory.MemoryManager
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, TaskDescription}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
 
-class ExecutorSuite extends SparkFunSuite {
+class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar {
 
   test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
     // mock some objects to make Executor.launchTask() happy
     val conf = new SparkConf
     val serializer = new JavaSerializer(conf)
-    val mockEnv = mock(classOf[SparkEnv])
-    val mockRpcEnv = mock(classOf[RpcEnv])
-    val mockMetricsSystem = mock(classOf[MetricsSystem])
-    val mockMemoryManager = mock(classOf[MemoryManager])
+    val mockEnv = mock[SparkEnv]
+    val mockRpcEnv = mock[RpcEnv]
+    val mockMetricsSystem = mock[MetricsSystem]
+    val mockMemoryManager = mock[MemoryManager]
     when(mockEnv.conf).thenReturn(conf)
     when(mockEnv.serializer).thenReturn(serializer)
     when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
@@ -55,16 +60,7 @@ class ExecutorSuite extends SparkFunSuite {
     val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
     val serializedTask = serializer.newInstance().serialize(
       new FakeTask(0, 0, Nil, fakeTaskMetrics))
-    val taskDescription = new TaskDescription(
-      taskId = 0,
-      attemptNumber = 0,
-      executorId = "",
-      name = "",
-      index = 0,
-      addedFiles = Map[String, Long](),
-      addedJars = Map[String, Long](),
-      properties = new Properties,
-      serializedTask)
+    val taskDescription = fakeTaskDescription(serializedTask)
 
     // we use latches to force the program to run in this order:
     // +-----------------------------+---------------------------------------+
@@ -86,7 +82,7 @@ class ExecutorSuite extends SparkFunSuite {
 
     val executorSuiteHelper = new ExecutorSuiteHelper
 
-    val mockExecutorBackend = mock(classOf[ExecutorBackend])
+    val mockExecutorBackend = mock[ExecutorBackend]
     when(mockExecutorBackend.statusUpdate(any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         var firstTime = true
@@ -133,6 +129,116 @@ class ExecutorSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("SPARK-19276: Handle Fetch Failed for all intervening user code") {
+    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+    val sc = new SparkContext(conf)
+
+    val serializer = SparkEnv.get.closureSerializer.newInstance()
+    val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+    val inputRDD = new FakeShuffleRDD(sc)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD)
+    val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+    val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+    val task = new ResultTask(
+      stageId = 1,
+      stageAttemptId = 0,
+      taskBinary = taskBinary,
+      partition = secondRDD.partitions(0),
+      locs = Seq(),
+      outputId = 0,
+      localProperties = new Properties(),
+      serializedTaskMetrics = serializedTaskMetrics
+    )
+
+    val serTask = serializer.serialize(task)
+    val taskDescription = fakeTaskDescription(serTask)
+
+    val mockBackend = mock[ExecutorBackend]
+    var executor: Executor = null
+    try {
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      executor.launchTask(mockBackend, taskDescription)
+      val startTime = System.currentTimeMillis()
+      val maxTime = startTime + 5000
+      while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) {
+        Thread.sleep(10)
+      }
+      val orderedMock = inOrder(mockBackend)
+      val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+      orderedMock.verify(mockBackend)
+        .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+      orderedMock.verify(mockBackend)
+        .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      // first statusUpdate for RUNNING has empty data
+      assert(statusCaptor.getAllValues().get(0).remaining() === 0)
+      // second update is more interesting
+      val failureData = statusCaptor.getAllValues.get(1)
+      val failReason = serializer.deserialize[TaskFailedReason](failureData)
+      assert(failReason.isInstanceOf[FetchFailed])
+    } finally {
+      if (executor != null) {
+        executor.stop()
+      }
+    }
+  }
+
+  private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
+    new TaskDescription(
+      taskId = 0,
+      attemptNumber = 0,
+      executorId = "",
+      name = "",
+      index = 0,
+      addedFiles = Map[String, Long](),
+      addedJars = Map[String, Long](),
+      properties = new Properties,
+      serializedTask)
+  }
+
+}
+
+class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    new Iterator[Int] {
+      override def hasNext: Boolean = true
+      override def next(): Int = {
+        throw new FetchFailedException(
+          bmAddress = BlockManagerId("1", "hostA", 1234),
+          shuffleId = 0,
+          mapId = 0,
+          reduceId = 0,
+          message = "fake fetch failure"
+        )
+      }
+    }
+  }
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
+  }
+}
+
+class SimplePartition extends Partition {
+  override def index: Int = 0
+}
+
+class FetchFailureHidingRDD(
+    sc: SparkContext,
+    val input: FakeShuffleRDD) extends RDD[Int](input) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    val inItr = input.compute(split, context)
+    try {
+      Iterator(inItr.size)
+    } catch {
+      case t: Throwable =>
+        throw new RuntimeException("User Exception that hides the original exception", t)
+    }
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
+  }
 }
 
 // Helps to test("SPARK-15963")
