Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -513,14 +514,13 @@ case class AdaptiveSparkPlanExec(
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
// Create a query stage only when all the child query stages are ready.
if (result.allChildStagesMaterialized) {
var newStage = newQueryStage(newPlan)
assert(newStage.isInstanceOf[ReusableQueryStageExec])
var newStage = newQueryStage(newPlan).asInstanceOf[ExchangeQueryStageExec]
if (conf.exchangeReuseEnabled) {
// Check the `stageCache` again for reuse. If a match is found, ditch the new stage
// and reuse the existing stage found in the `stageCache`, otherwise update the
// `stageCache` with the new stage.
val queryStage = context.stageCache.getOrElseUpdate(
newStage.plan.canonicalized, newStage.asInstanceOf[ReusableQueryStageExec])
newStage.plan.canonicalized, newStage)
if (queryStage.ne(newStage)) {
newStage = reuseQueryStage(queryStage, e)
}
Expand Down Expand Up @@ -561,49 +561,43 @@ case class AdaptiveSparkPlanExec(
}

/**
 * Creates a new [[QueryStageExec]] wrapping the given plan.
 *
 * For an `Exchange`, the child is first re-optimized via `optimizeQueryStage` with
 * `isFinalStage = false`. The post-stage-creation physical rules are then applied, and the
 * result is wrapped in the stage type matching the kind of `plan`: shuffle, broadcast or
 * in-memory table scan. The match has no fallback case for any other kind of plan.
 *
 * An error is raised if the applied rules replaced the node with a different kind of node.
 * Each call consumes one stage id by incrementing `currentStageId`, and sets the logical link
 * for the new stage via `setLogicalLinkForNewQueryStage`.
 *
 * @param plan the plan to wrap in a query stage
 * @return the newly created query stage
 */
private def newQueryStage(plan: SparkPlan): QueryStageExec = {
// Only exchanges get their child re-optimized; any other plan is passed through untouched.
val optimizedPlan = plan match {
case e: Exchange =>
e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
case _ => plan
}
val newPlan = applyPhysicalRules(
optimizedPlan,
postStageCreationRules(outputsColumnar = plan.supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
val queryStage = plan match {
case s: ShuffleExchangeLike =>
val optimizedPlan = optimizeQueryStage(s.child, isFinalStage = false)
val newShuffle = applyPhysicalRules(
s.withNewChildren(Seq(optimizedPlan)),
postStageCreationRules(outputsColumnar = s.supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
throw new IllegalStateException(
if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
throw SparkException.internalError(
"Custom columnar rules cannot transform shuffle node to something else.")
}
ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
// The canonicalized form passed to the stage comes from the original exchange `s`,
// not from the plan produced by the rules.
ShuffleQueryStageExec(currentStageId, newPlan, s.canonicalized)
case b: BroadcastExchangeLike =>
val optimizedPlan = optimizeQueryStage(b.child, isFinalStage = false)
val newBroadcast = applyPhysicalRules(
b.withNewChildren(Seq(optimizedPlan)),
postStageCreationRules(outputsColumnar = b.supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
throw new IllegalStateException(
if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
throw SparkException.internalError(
"Custom columnar rules cannot transform broadcast node to something else.")
}
BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
// As with shuffles, the canonicalized form comes from the original exchange `b`.
BroadcastQueryStageExec(currentStageId, newPlan, b.canonicalized)
case i: InMemoryTableScanExec =>
val newInMemoryTableScan = applyPhysicalRules(
i,
postStageCreationRules(outputsColumnar = i.supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
if (!newInMemoryTableScan.isInstanceOf[InMemoryTableScanExec]) {
throw new IllegalStateException("Custom columnar rules cannot transform " +
if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
throw SparkException.internalError("Custom columnar rules cannot transform " +
"`InMemoryTableScanExec` node to something else.")
}
TableCacheQueryStageExec(
currentStageId, newInMemoryTableScan.asInstanceOf[InMemoryTableScanExec])
TableCacheQueryStageExec(currentStageId, newPlan.asInstanceOf[InMemoryTableScanExec])
}
// The id has been handed to the stage created above; advance it for the next stage.
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, plan)
queryStage
}

private def reuseQueryStage(
existing: ReusableQueryStageExec,
exchange: Exchange): QueryStageExec = {
existing: ExchangeQueryStageExec,
exchange: Exchange): ExchangeQueryStageExec = {
val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, exchange)
Expand Down Expand Up @@ -770,7 +764,7 @@ case class AdaptiveSparkPlanExec(
currentPhysicalPlan.foreach {
// earlyFailedStage is the stage which failed before calling doMaterialize,
// so we should avoid calling cancel on it, since that would re-trigger the failure.
case s: QueryStageExec if !earlyFailedStage.contains(s.id) =>
case s: ExchangeQueryStageExec if !earlyFailedStage.contains(s.id) =>
try {
s.cancel()
} catch {
Expand Down Expand Up @@ -847,8 +841,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
/**
* The exchange-reuse map shared across the entire query, including sub-queries.
*/
val stageCache: TrieMap[SparkPlan, ReusableQueryStageExec] =
new TrieMap[SparkPlan, ReusableQueryStageExec]()
val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
new TrieMap[SparkPlan, ExchangeQueryStageExec]()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ abstract class QueryStageExec extends LeafExecNode {
*/
val plan: SparkPlan

/**
* Cancel the stage materialization if in progress; otherwise do nothing.
*/
def cancel(): Unit

/**
* Materialize this query stage, to prepare for the execution, like submitting map stages,
* broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
Expand Down Expand Up @@ -142,13 +137,18 @@ abstract class QueryStageExec extends LeafExecNode {
}

/**
* There are 2 kinds of reusable query stages:
* There are 2 kinds of exchange query stages:
* 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
* another job to execute the further operators.
* 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
* broadcasts the array before executing the further operators.
*/
abstract class ReusableQueryStageExec extends QueryStageExec {
abstract class ExchangeQueryStageExec extends QueryStageExec {

/**
* Cancel the stage materialization if in progress; otherwise do nothing.
*/
def cancel(): Unit

/**
* The canonicalized plan before applying query stage optimizer rules.
Expand All @@ -157,7 +157,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {

override def doCanonicalize(): SparkPlan = _canonicalized

def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec
}

/**
Expand All @@ -170,7 +170,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {

@transient val shuffle = plan match {
case s: ShuffleExchangeLike => s
Expand All @@ -183,7 +183,8 @@ case class ShuffleQueryStageExec(

override protected def doMaterialize(): Future[Any] = shuffleFuture

override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
val reuse = ShuffleQueryStageExec(
newStageId,
ReusedExchangeExec(newOutput, shuffle),
Expand Down Expand Up @@ -221,7 +222,7 @@ case class ShuffleQueryStageExec(
case class BroadcastQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {

@transient val broadcast = plan match {
case b: BroadcastExchangeLike => b
Expand All @@ -234,7 +235,8 @@ case class BroadcastQueryStageExec(
broadcast.submitBroadcastJob
}

override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
override def newReuseInstance(
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
val reuse = BroadcastQueryStageExec(
newStageId,
ReusedExchangeExec(newOutput, broadcast),
Expand Down Expand Up @@ -285,11 +287,5 @@ case class TableCacheQueryStageExec(

override def isMaterialized: Boolean = super.isMaterialized || inMemoryTableScan.isMaterialized

override def cancel(): Unit = {
if (!isMaterialized) {
logDebug(s"Skip canceling the table cache stage: $id")
}
}

override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
}