Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 43 additions & 39 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -378,59 +378,63 @@ class DAGScheduler(
* the provided firstJobId.
*/
private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, firstJobId)
case _ =>
waitingForVisit.push(dep.rdd)
}
}
}
}
waitingForVisit.push(rdd)
while (waitingForVisit.nonEmpty) {
visit(waitingForVisit.pop())
}
parents.toList
getShuffleDependencies(rdd).map { shuffleDep =>
getShuffleMapStage(shuffleDep, firstJobId)
}.toList
}

/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
val parents = new Stack[ShuffleDependency[_, _, _]]
val ancestors = new Stack[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is getting lost ... but I think that is OK, I have no idea what it is referring to here. Its also in newOrUsedShuffleStage, which is probably the only place it belongs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I decided that it was probably an accident that it's here -- as you said, I think it only belongs in newOrUsedShuffleStage.

// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
parents.push(shufDep)
}
case _ =>
}
waitingForVisit.push(dep.rdd)
waitingForVisit.push(rdd)
while (waitingForVisit.nonEmpty) {
val toVisit = waitingForVisit.pop()
if (!visited(toVisit)) {
visited += toVisit
getShuffleDependencies(toVisit).foreach { shuffleDep =>
if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
ancestors.push(shuffleDep)
waitingForVisit.push(shuffleDep.rdd)
} // Otherwise, the dependency and it's ancestors have already been registered.
}
}
}
ancestors
}

/**
* Returns shuffle dependencies that are immediate parents of the given RDD.
*
* This function will not return more distant ancestors. For example, if C has a shuffle
* dependency on B which has a shuffle dependency on A:
*
* A <-- B <-- C
*
* calling this function with rdd C will only return the B <-- C dependency.
*
* This function is scheduler-visible for the purpose of unit testing.
*/
private[scheduler] def getShuffleDependencies(
rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
val parents = new HashSet[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
val waitingForVisit = new Stack[RDD[_]]
waitingForVisit.push(rdd)
while (waitingForVisit.nonEmpty) {
visit(waitingForVisit.pop())
val toVisit = waitingForVisit.pop()
if (!visited(toVisit)) {
visited += toVisit
toVisit.dependencies.foreach {
case shuffleDep: ShuffleDependency[_, _, _] =>
parents += shuffleDep
case dependency =>
waitingForVisit.push(dependency.rdd)
}
}
}
parents
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assertDataStructuresEmpty()
}

/**
* Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that
* getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
* RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
* denotes a shuffle dependency):
*
* A <------------s---------,
* \
* B <--s-- C <--s-- D <--n---`-- E
*
* Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
* shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
*/
test("getShuffleDependencies correctly returns only direct shuffle parents") {
val rddA = new MyRDD(sc, 2, Nil)
val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
val rddB = new MyRDD(sc, 2, Nil)
val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
val rddC = new MyRDD(sc, 1, List(shuffleDepB))
val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1))
val rddD = new MyRDD(sc, 1, List(shuffleDepC))
val narrowDepD = new OneToOneDependency(rddD)
val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker)

assert(scheduler.getShuffleDependencies(rddA) === Set())
assert(scheduler.getShuffleDependencies(rddB) === Set())
assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB))
assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC))
assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
Expand Down