@@ -23,29 +23,34 @@ import java.util.concurrent.CountDownLatch
 
 import scala.collection.mutable.Map
 
-import org.mockito.Matchers._
-import org.mockito.Mockito.{mock, when}
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{inOrder, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.memory.MemoryManager
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, TaskDescription}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
 
-class ExecutorSuite extends SparkFunSuite {
+class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar {
 
   test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
     // mock some objects to make Executor.launchTask() happy
     val conf = new SparkConf
     val serializer = new JavaSerializer(conf)
-    val mockEnv = mock(classOf[SparkEnv])
-    val mockRpcEnv = mock(classOf[RpcEnv])
-    val mockMetricsSystem = mock(classOf[MetricsSystem])
-    val mockMemoryManager = mock(classOf[MemoryManager])
+    val mockEnv = mock[SparkEnv]
+    val mockRpcEnv = mock[RpcEnv]
+    val mockMetricsSystem = mock[MetricsSystem]
+    val mockMemoryManager = mock[MemoryManager]
     when(mockEnv.conf).thenReturn(conf)
     when(mockEnv.serializer).thenReturn(serializer)
     when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
@@ -55,16 +60,7 @@ class ExecutorSuite extends SparkFunSuite {
     val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
     val serializedTask = serializer.newInstance().serialize(
       new FakeTask(0, 0, Nil, fakeTaskMetrics))
-    val taskDescription = new TaskDescription(
-      taskId = 0,
-      attemptNumber = 0,
-      executorId = "",
-      name = "",
-      index = 0,
-      addedFiles = Map[String, Long](),
-      addedJars = Map[String, Long](),
-      properties = new Properties,
-      serializedTask)
+    val taskDescription = fakeTaskDescription(serializedTask)
 
     // we use latches to force the program to run in this order:
     // +-----------------------------+---------------------------------------+
@@ -86,7 +82,7 @@ class ExecutorSuite extends SparkFunSuite {
 
     val executorSuiteHelper = new ExecutorSuiteHelper
 
-    val mockExecutorBackend = mock(classOf[ExecutorBackend])
+    val mockExecutorBackend = mock[ExecutorBackend]
     when(mockExecutorBackend.statusUpdate(any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         var firstTime = true
@@ -133,6 +129,116 @@ class ExecutorSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("SPARK-19276: Handle Fetch Failed for all intervening user code") {
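+    // Even when user code catches the FetchFailedException and rethrows it wrapped in another
+    // exception, the executor should still report the task failure to the backend as FetchFailed.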
+    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+    sc = new SparkContext(conf)
+
+    val serializer = SparkEnv.get.closureSerializer.newInstance()
+    val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+    val inputRDD = new FakeShuffleRDD(sc)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD)
+    val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+    val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+    val task = new ResultTask(
+      stageId = 1,
+      stageAttemptId = 0,
+      taskBinary = taskBinary,
+      partition = secondRDD.partitions(0),
+      locs = Seq(),
+      outputId = 0,
+      localProperties = new Properties(),
+      serializedTaskMetrics = serializedTaskMetrics
+    )
+
+    val serTask = serializer.serialize(task)
+    val taskDescription = fakeTaskDescription(serTask)
+
+
+    val mockBackend = mock[ExecutorBackend]
+    var executor: Executor = null
+    try {
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      executor.launchTask(mockBackend, taskDescription)
+      val startTime = System.currentTimeMillis()
+      val maxTime = startTime + 5000
+      while (executor.numRunningTasks > 0 && System.currentTimeMillis() < maxTime) {
+        Thread.sleep(10)
+      }
+      val orderedMock = inOrder(mockBackend)
+      val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+      orderedMock.verify(mockBackend)
+        .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+      orderedMock.verify(mockBackend)
+        .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      // first statusUpdate for RUNNING has empty data
+      assert(statusCaptor.getAllValues().get(0).remaining() === 0)
+      // second update is more interesting
+      val failureData = statusCaptor.getAllValues.get(1)
+      val failReason = serializer.deserialize[TaskFailedReason](failureData)
+      assert(failReason.isInstanceOf[FetchFailed])
+    } finally {
+      if (executor != null) {
+        executor.stop()
+      }
+    }
+  }
+
+  private def fakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
+    new TaskDescription(
+      taskId = 0,
+      attemptNumber = 0,
+      executorId = "",
+      name = "",
+      index = 0,
+      addedFiles = Map[String, Long](),
+      addedJars = Map[String, Long](),
+      properties = new Properties,
+      serializedTask)
+  }
+
+}
+
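+// An RDD whose iterator always throws a FetchFailedException, simulating a failed shuffle fetch.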
+class FakeShuffleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    new Iterator[Int] {
+      override def hasNext: Boolean = true
+      override def next(): Int = {
+        throw new FetchFailedException(
+          bmAddress = BlockManagerId("1", "hostA", 1234),
+          shuffleId = 0,
+          mapId = 0,
+          reduceId = 0,
+          message = "fake fetch failure"
+        )
+      }
+    }
+  }
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
+  }
+}
+
+class SimplePartition extends Partition {
+  override def index: Int = 0
+}
+
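+// Wraps FakeShuffleRDD and, like user code with a broad catch block, rethrows the fetch
+// failure inside a RuntimeException, hiding the original FetchFailedException.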
+class FetchFailureHidingRDD(
+    sc: SparkContext,
+    val input: FakeShuffleRDD) extends RDD[Int](input) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    val inItr = input.compute(split, context)
+    try {
+      Iterator(inItr.size)
+    } catch {
+      case t: Throwable =>
+        throw new RuntimeException("User Exception that hides the original exception", t)
+    }
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
+  }
+}
 }
 
 // Helps to test("SPARK-15963")