Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -220,10 +221,24 @@ case class AdaptiveSparkPlanExec(
}

private def getExecutionId: Option[Long] = {
// If the `QueryExecution` does not match the current execution ID, it means the execution ID
// belongs to another (parent) query, and we should not call update UI in this query.
Option(context.session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY))
.map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
.map(_.toLong)
}

private lazy val shouldUpdatePlan: Boolean = {
// There are two cases that should not update plan:
// 1. When executing subqueries, we can't update the query plan in the UI as the
// UI doesn't support partial update yet. However, the subquery may have been
// optimized into a different plan and we must let the UI know the SQL metrics
// of the new plan nodes, so that it can track the valid accumulator updates later
// and display SQL metrics correctly.
// 2. If the `QueryExecution` does not match the current execution ID, it means the execution
// ID belongs to another (parent) query, and we should not call update UI in this query.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we mention that this can happen with table cache?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, added

// e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
//
// That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this
// query execution needs to do a plan update for the UI.
!isSubquery && getExecutionId.exists(SQLExecution.getQueryExecution(_) eq context.qe)
}

def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)
Expand Down Expand Up @@ -345,7 +360,7 @@ case class AdaptiveSparkPlanExec(
// Subqueries that don't belong to any query stage of the main query will execute after the
// last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
// the newly generated nodes of those subqueries are updated.
if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`shouldUpdatePlan` is not strictly required here since we already check it inside `onUpdatePlan`. The reason to keep it here is to quickly skip `currentPhysicalPlan.exists(_.subqueries.nonEmpty)`.

getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
logOnLevel(s"Final plan:\n$currentPhysicalPlan")
Expand Down Expand Up @@ -499,12 +514,13 @@ case class AdaptiveSparkPlanExec(
// Create a query stage only when all the child query stages are ready.
if (result.allChildStagesMaterialized) {
var newStage = newQueryStage(newPlan)
assert(newStage.isInstanceOf[ReusableQueryStageExec])
if (conf.exchangeReuseEnabled) {
// Check the `stageCache` again for reuse. If a match is found, ditch the new stage
// and reuse the existing stage found in the `stageCache`, otherwise update the
// `stageCache` with the new stage.
val queryStage = context.stageCache.getOrElseUpdate(
newStage.plan.canonicalized, newStage)
newStage.plan.canonicalized, newStage.asInstanceOf[ReusableQueryStageExec])
if (queryStage.ne(newStage)) {
newStage = reuseQueryStage(queryStage, e)
}
Expand All @@ -520,6 +536,14 @@ case class AdaptiveSparkPlanExec(
}
}

case i: InMemoryTableScanExec =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: if the table cache is already materialized (second access of the cache), do we still need to wrap it with TableCacheQueryStage?

Copy link
Contributor Author

@ulysses-you ulysses-you Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`TableCacheQueryStage` provides a base framework for runtime statistics, so I think wrapping it is more suitable for the AQE framework, e.g., we can mark `isRuntime = true` in `Statistics`.

val newStage = newQueryStage(i)
val isMaterialized = newStage.isMaterialized
CreateStageResult(
newPlan = newStage,
allChildStagesMaterialized = isMaterialized,
newStages = if (isMaterialized) Seq.empty else Seq(newStage))

case q: QueryStageExec =>
CreateStageResult(newPlan = q,
allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty)
Expand All @@ -536,10 +560,10 @@ case class AdaptiveSparkPlanExec(
}
}

private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false)
val queryStage = e match {
private def newQueryStage(plan: SparkPlan): QueryStageExec = {
val queryStage = plan match {
case s: ShuffleExchangeLike =>
val optimizedPlan = optimizeQueryStage(s.child, isFinalStage = false)
val newShuffle = applyPhysicalRules(
s.withNewChildren(Seq(optimizedPlan)),
postStageCreationRules(outputsColumnar = s.supportsColumnar),
Expand All @@ -550,6 +574,7 @@ case class AdaptiveSparkPlanExec(
}
ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
case b: BroadcastExchangeLike =>
val optimizedPlan = optimizeQueryStage(b.child, isFinalStage = false)
val newBroadcast = applyPhysicalRules(
b.withNewChildren(Seq(optimizedPlan)),
postStageCreationRules(outputsColumnar = b.supportsColumnar),
Expand All @@ -559,13 +584,26 @@ case class AdaptiveSparkPlanExec(
"Custom columnar rules cannot transform broadcast node to something else.")
}
BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
case i: InMemoryTableScanExec =>
val newInMemoryTableScan = applyPhysicalRules(
i,
postStageCreationRules(outputsColumnar = i.supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
if (!newInMemoryTableScan.isInstanceOf[InMemoryTableScanExec]) {
throw new IllegalStateException("Custom columnar rules cannot transform " +
"`InMemoryTableScanExec` node to something else.")
}
TableCacheQueryStageExec(
currentStageId, newInMemoryTableScan.asInstanceOf[InMemoryTableScanExec])
}
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, e)
setLogicalLinkForNewQueryStage(queryStage, plan)
queryStage
}

private def reuseQueryStage(existing: QueryStageExec, exchange: Exchange): QueryStageExec = {
private def reuseQueryStage(
existing: ReusableQueryStageExec,
exchange: Exchange): QueryStageExec = {
val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, exchange)
Expand Down Expand Up @@ -707,12 +745,7 @@ case class AdaptiveSparkPlanExec(
* Notify the listeners of the physical plan change.
*/
private def onUpdatePlan(executionId: Long, newSubPlans: Seq[SparkPlan]): Unit = {
if (isSubquery) {
// When executing subqueries, we can't update the query plan in the UI as the
// UI doesn't support partial update yet. However, the subquery may have been
// optimized into a different plan and we must let the UI know the SQL metrics
// of the new plan nodes, so that it can track the valid accumulator updates later
// and display SQL metrics correctly.
if (!shouldUpdatePlan) {
val newMetrics = newSubPlans.flatMap { p =>
p.flatMap(_.metrics.values.map(m => SQLPlanMetric(m.name.get, m.id, m.metricType)))
}
Expand Down Expand Up @@ -814,8 +847,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
/**
* The exchange-reuse map shared across the entire query, including sub-queries.
*/
val stageCache: TrieMap[SparkPlan, QueryStageExec] =
new TrieMap[SparkPlan, QueryStageExec]()
val stageCache: TrieMap[SparkPlan, ReusableQueryStageExec] =
new TrieMap[SparkPlan, ReusableQueryStageExec]()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.V1WriteCommand
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
Expand Down Expand Up @@ -88,12 +89,15 @@ case class InsertAdaptiveSparkPlan(
// - The query may need to add exchanges. It's an overkill to run `EnsureRequirements` here, so
// we just check `SparkPlan.requiredChildDistribution` and see if it's possible that the
// query needs to add exchanges later.
// - The query contains nested `AdaptiveSparkPlanExec`.
// - The query contains sub-query.
private def shouldApplyAQE(plan: SparkPlan, isSubquery: Boolean): Boolean = {
conf.getConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) || isSubquery || {
plan.exists {
case _: Exchange => true
case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true
case i: InMemoryTableScanExec
if i.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => true
case p => p.expressions.exists(_.exists {
case _: SubqueryExpression => true
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,16 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* A query stage is an independent subgraph of the query plan. Query stage materializes its output
* before proceeding with further operators of the query plan. The data statistics of the
* materialized output can be used to optimize subsequent query stages.
*
* There are 2 kinds of query stages:
* 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
* another job to execute the further operators.
* 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
* broadcasts the array before executing the further operators.
* A query stage is an independent subgraph of the query plan. AQE framework will materialize its
* output before proceeding with further operators of the query plan. The data statistics of the
* materialized output can be used to optimize the rest of the query plan.
*/
abstract class QueryStageExec extends LeafExecNode {

Expand All @@ -55,18 +51,6 @@ abstract class QueryStageExec extends LeafExecNode {
*/
val plan: SparkPlan

/**
* The canonicalized plan before applying query stage optimizer rules.
*/
val _canonicalized: SparkPlan

/**
* Materialize this query stage, to prepare for the execution, like submitting map stages,
* broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
* stage is ready.
*/
def doMaterialize(): Future[Any]

/**
* Cancel the stage materialization if in progress; otherwise do nothing.
*/
Expand All @@ -82,7 +66,7 @@ abstract class QueryStageExec extends LeafExecNode {
doMaterialize()
}

def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
protected def doMaterialize(): Future[Any]

/**
* Returns the runtime statistics after stage materialization.
Expand Down Expand Up @@ -121,7 +105,6 @@ abstract class QueryStageExec extends LeafExecNode {
override def supportsColumnar: Boolean = plan.supportsColumnar
protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
override def doCanonicalize(): SparkPlan = _canonicalized

protected override def stringArgs: Iterator[Any] = Iterator.single(id)

Expand Down Expand Up @@ -158,6 +141,25 @@ abstract class QueryStageExec extends LeafExecNode {
}
}

/**
 * Base class for query stages that can be reused.
 *
 * Two kinds of query stages are reusable:
 *   1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark
 *      launches another job to execute the further operators.
 *   2. Broadcast query stage. This stage materializes its output to an array in driver JVM.
 *      Spark broadcasts the array before executing the further operators.
 */
abstract class ReusableQueryStageExec extends QueryStageExec {

  /**
   * The canonicalized plan before applying query stage optimizer rules.
   * It is what [[doCanonicalize]] returns for this stage.
   */
  val _canonicalized: SparkPlan

  /**
   * Builds a query stage that reuses this one.
   *
   * @param newStageId the stage id to give to the returned stage.
   * @param newOutput the output attributes the returned stage should expose.
   * @return the query stage to use in place of a newly created one.
   */
  def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec

  /** Canonicalization is fixed up front: always answer with `_canonicalized`. */
  override def doCanonicalize(): SparkPlan = _canonicalized
}

/**
* A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
*
Expand All @@ -168,7 +170,7 @@ abstract class QueryStageExec extends LeafExecNode {
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
override val _canonicalized: SparkPlan) extends QueryStageExec {
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {

@transient val shuffle = plan match {
case s: ShuffleExchangeLike => s
Expand All @@ -179,7 +181,7 @@ case class ShuffleQueryStageExec(

@transient private lazy val shuffleFuture = shuffle.submitShuffleJob

override def doMaterialize(): Future[Any] = shuffleFuture
override protected def doMaterialize(): Future[Any] = shuffleFuture

override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
val reuse = ShuffleQueryStageExec(
Expand Down Expand Up @@ -219,7 +221,7 @@ case class ShuffleQueryStageExec(
case class BroadcastQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
override val _canonicalized: SparkPlan) extends QueryStageExec {
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {

@transient val broadcast = plan match {
case b: BroadcastExchangeLike => b
Expand All @@ -228,7 +230,7 @@ case class BroadcastQueryStageExec(
throw new IllegalStateException(s"wrong plan for broadcast stage:\n ${plan.treeString}")
}

override def doMaterialize(): Future[Any] = {
override protected def doMaterialize(): Future[Any] = {
broadcast.submitBroadcastJob
}

Expand All @@ -250,3 +252,44 @@ case class BroadcastQueryStageExec(

override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
}

/**
 * A table cache query stage whose child is an [[InMemoryTableScanExec]].
 *
 * Materializing this stage submits a job over every partition of the scan's base cache RDD.
 * The stage also reports itself as materialized whenever the wrapped scan does.
 *
 * @param id the query stage id.
 * @param plan the underlying plan; it must be an [[InMemoryTableScanExec]], otherwise
 *             construction fails with an `IllegalStateException`.
 */
case class TableCacheQueryStageExec(
    override val id: Int,
    override val plan: SparkPlan) extends QueryStageExec {

  /** The underlying plan narrowed to its only supported type. */
  @transient val inMemoryTableScan = plan match {
    case scan: InMemoryTableScanExec => scan
    case _ =>
      throw new IllegalStateException(s"wrong plan for table cache stage:\n ${plan.treeString}")
  }

  /**
   * Handle of the materialization job. Being a `lazy val`, the job is submitted on first access
   * only. Both the per-partition function and the result handler discard what they receive, so
   * the job produces nothing beyond having run over every partition of the cache RDD
   * (presumably that is what gets the cached batches built -- confirm against
   * `InMemoryTableScanExec.baseCacheRDD`).
   */
  @transient
  private lazy val future: FutureAction[Unit] = {
    val cacheRdd = inMemoryTableScan.baseCacheRDD()
    val everyPartition = (0 until cacheRdd.getNumPartitions).toSeq
    sparkContext.submitJob(
      cacheRdd,
      (_: Iterator[CachedBatch]) => (),
      everyPartition,
      (_: Int, _: Unit) => (),
      ()
    )
  }

  /** Returns the future of the (lazily submitted) materialization job. */
  override protected def doMaterialize(): Future[Any] = future

  /**
   * True when the parent class considers the stage materialized, or when the wrapped scan
   * reports itself as materialized.
   */
  override def isMaterialized: Boolean = {
    super.isMaterialized ||
      inMemoryTableScan.isMaterialized
  }

  /**
   * Never cancels anything: when the stage is not yet materialized this only emits a debug log
   * entry, otherwise it does nothing.
   */
  override def cancel(): Unit = {
    if (!isMaterialized) logDebug(s"Skip canceling the table cache stage: $id")
  }

  /** Runtime statistics are taken from the cached relation's `computeStats()`. */
  override def getRuntimeStatistics: Statistics = {
    val cachedRelation = inMemoryTableScan.relation
    cachedRelation.computeStats()
  }
}
Loading