37 changes: 22 additions & 15 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -905,22 +905,29 @@ class DAGScheduler(
       return
     }
 
-    val tasks: Seq[Task[_]] = stage match {
-      case stage: ShuffleMapStage =>
-        partitionsToCompute.map { id =>
-          val locs = getPreferredLocs(stage.rdd, id)
-          val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
-        }
+    val tasks: Seq[Task[_]] = try {
+      stage match {
+        case stage: ShuffleMapStage =>
+          partitionsToCompute.map { id =>
+            val locs = getPreferredLocs(stage.rdd, id)
+            val part = stage.rdd.partitions(id)
+            new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          }
 
-      case stage: ResultStage =>
-        val job = stage.resultOfJob.get
-        partitionsToCompute.map { id =>
-          val p: Int = job.partitions(id)
-          val part = stage.rdd.partitions(p)
-          val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
-        }
+        case stage: ResultStage =>
+          val job = stage.resultOfJob.get
+          partitionsToCompute.map { id =>
+            val p: Int = job.partitions(id)
+            val part = stage.rdd.partitions(p)
+            val locs = getPreferredLocs(stage.rdd, p)
+            new ResultTask(stage.id, taskBinary, part, locs, id)
+          }
+      }
+    } catch {
+      case NonFatal(e) =>
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
     }
 
     if (tasks.size > 0) {
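The fix above wraps task creation in try/catch with scala.util.control.NonFatal, so an exception escaping a user-overridable RDD method (getPartitions, getPreferredLocations) aborts only the affected stage instead of killing the DAGScheduler event loop. A minimal standalone sketch of that guard pattern, with illustrative names (safely is not a Spark API):

import scala.util.control.NonFatal

// Contain ordinary exceptions from a user-supplied callback while letting
// fatal errors (OutOfMemoryError, InterruptedException, ...) propagate.
def safely[T](label: String)(body: => T): Option[T] = {
  try {
    Some(body)
  } catch {
    case NonFatal(e) =>
      // In the DAGScheduler this branch is abortStage(...) plus cleanup.
      println(s"$label failed: $e")
      None
  }
}

safely("task creation") { throw new RuntimeException("boom") }  // None; caller survives
safely("task creation") { 42 }                                  // Some(42)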
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -784,6 +784,37 @@ class DAGSchedulerSuite
    assert(sc.parallelize(1 to 10, 2).first() === 1)
  }

  test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") {
    val e1 = intercept[DAGSchedulerSuiteDummyException] {
      val rdd = new MyRDD(sc, 2, Nil) {
        override def getPartitions: Array[Partition] = {
          throw new DAGSchedulerSuiteDummyException
        }
      }
      rdd.reduceByKey(_ + _, 1).count()
    }

    // Make sure we can still run local commands as well as cluster commands.
    assert(sc.parallelize(1 to 10, 2).count() === 10)
    assert(sc.parallelize(1 to 10, 2).first() === 1)
  }
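For readers without the suite open: MyRDD is a DAGSchedulerSuite helper whose overridable methods the tests above use to inject failures. A rough, simplified sketch of such a stub pair RDD (StubRDD is an illustrative name; the real helper also models dependencies and preferred locations):

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Minimal stub pair RDD whose getPartitions (and, in subclasses,
// getPreferredLocations) can be overridden to throw.
class StubRDD(sc: SparkContext, numPartitions: Int)
  extends RDD[(Int, Int)](sc, Nil) {

  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
    Iterator.empty

  override def getPartitions: Array[Partition] =
    (0 until numPartitions).map { i =>
      new Partition { override def index: Int = i }
    }.toArray
}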

test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") {
val e1 = intercept[SparkException] {
val rdd = new MyRDD(sc, 2, Nil) {
override def getPreferredLocations(split: Partition): Seq[String] = {
throw new DAGSchedulerSuiteDummyException
}
}
rdd.count()
}
assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName))

Contributor: once #7014 gets merged we can actually check the causing exception (just a random comment since I just looked at that PR, this is fine as is for now)

Contributor (author): I'm going to merge this PR now since we're going to need this string comparison in order for the regression test to run in backport branches.

Contributor: oh yeah, that is fine with me -- sorry I didn't mean to imply this should wait on the other PR. I was really just making a random observation, sorry if I confused things.

    // Make sure we can still run local commands as well as cluster commands.
    assert(sc.parallelize(1 to 10, 2).count() === 10)
    assert(sc.parallelize(1 to 10, 2).first() === 1)
  }
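As the review thread above notes, once exception chaining (#7014) lands, the test could assert on the cause directly instead of comparing message strings. A hypothetical tightened assertion under that assumption (not part of this PR; e1 is the SparkException intercepted above):

// Hypothetical: assumes the SparkException carries the original failure
// as its cause, which this PR does not yet rely on.
assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException])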

test("accumulator not calculated for resubmitted result stage") {
// just for register
val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)