diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 22b653c281fa..7c585aa3da4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -220,10 +221,24 @@ case class AdaptiveSparkPlanExec(
   }

   private def getExecutionId: Option[Long] = {
-    // If the `QueryExecution` does not match the current execution ID, it means the execution ID
-    // belongs to another (parent) query, and we should not call update UI in this query.
     Option(context.session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY))
-      .map(_.toLong).filter(SQLExecution.getQueryExecution(_) eq context.qe)
+      .map(_.toLong)
+  }
+
+  private lazy val shouldUpdatePlan: Boolean = {
+    // There are two cases that should not update plan:
+    // 1. When executing subqueries, we can't update the query plan in the UI as the
+    //    UI doesn't support partial update yet. However, the subquery may have been
+    //    optimized into a different plan and we must let the UI know the SQL metrics
+    //    of the new plan nodes, so that it can track the valid accumulator updates later
+    //    and display SQL metrics correctly.
+    // 2. If the `QueryExecution` does not match the current execution ID, it means the execution
+    //    ID belongs to another (parent) query, and we should not call update UI in this query.
+    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
+    //
+    // That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this
+    // query execution needs to do a plan update for the UI.
+    !isSubquery && getExecutionId.exists(SQLExecution.getQueryExecution(_) eq context.qe)
   }

   def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)
@@ -345,7 +360,7 @@ case class AdaptiveSparkPlanExec(
     // Subqueries that don't belong to any query stage of the main query will execute after the
     // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
     // the newly generated nodes of those subqueries are updated.
-    if (!isSubquery && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
+    if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
       getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
     }
     logOnLevel(s"Final plan:\n$currentPhysicalPlan")
@@ -499,12 +514,13 @@ case class AdaptiveSparkPlanExec(
           // Create a query stage only when all the child query stages are ready.
           if (result.allChildStagesMaterialized) {
             var newStage = newQueryStage(newPlan)
+            assert(newStage.isInstanceOf[ReusableQueryStageExec])
             if (conf.exchangeReuseEnabled) {
               // Check the `stageCache` again for reuse. If a match is found, ditch the new stage
               // and reuse the existing stage found in the `stageCache`, otherwise update the
               // `stageCache` with the new stage.
               val queryStage = context.stageCache.getOrElseUpdate(
-                newStage.plan.canonicalized, newStage)
+                newStage.plan.canonicalized, newStage.asInstanceOf[ReusableQueryStageExec])
               if (queryStage.ne(newStage)) {
                 newStage = reuseQueryStage(queryStage, e)
               }
@@ -520,6 +536,14 @@ case class AdaptiveSparkPlanExec(
           }
         }

+      case i: InMemoryTableScanExec =>
+        val newStage = newQueryStage(i)
+        val isMaterialized = newStage.isMaterialized
+        CreateStageResult(
+          newPlan = newStage,
+          allChildStagesMaterialized = isMaterialized,
+          newStages = if (isMaterialized) Seq.empty else Seq(newStage))
+
       case q: QueryStageExec =>
         CreateStageResult(newPlan = q,
           allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty)
@@ -536,10 +560,10 @@ case class AdaptiveSparkPlanExec(
     }
   }

-  private def newQueryStage(e: Exchange): QueryStageExec = {
-    val optimizedPlan = optimizeQueryStage(e.child, isFinalStage = false)
-    val queryStage = e match {
+  private def newQueryStage(plan: SparkPlan): QueryStageExec = {
+    val queryStage = plan match {
       case s: ShuffleExchangeLike =>
+        val optimizedPlan = optimizeQueryStage(s.child, isFinalStage = false)
         val newShuffle = applyPhysicalRules(
           s.withNewChildren(Seq(optimizedPlan)),
           postStageCreationRules(outputsColumnar = s.supportsColumnar),
@@ -550,6 +574,7 @@ case class AdaptiveSparkPlanExec(
         }
         ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
       case b: BroadcastExchangeLike =>
+        val optimizedPlan = optimizeQueryStage(b.child, isFinalStage = false)
         val newBroadcast = applyPhysicalRules(
           b.withNewChildren(Seq(optimizedPlan)),
           postStageCreationRules(outputsColumnar = b.supportsColumnar),
@@ -559,13 +584,26 @@ case class AdaptiveSparkPlanExec(
             "Custom columnar rules cannot transform broadcast node to something else.")
         }
         BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
+      case i: InMemoryTableScanExec =>
+        val newInMemoryTableScan = applyPhysicalRules(
+          i,
+          postStageCreationRules(outputsColumnar = i.supportsColumnar),
+          Some((planChangeLogger, "AQE Post Stage Creation")))
+        if (!newInMemoryTableScan.isInstanceOf[InMemoryTableScanExec]) {
+          throw new IllegalStateException("Custom columnar rules cannot transform " +
+            "`InMemoryTableScanExec` node to something else.")
+        }
+        TableCacheQueryStageExec(
+          currentStageId, newInMemoryTableScan.asInstanceOf[InMemoryTableScanExec])
     }
     currentStageId += 1
-    setLogicalLinkForNewQueryStage(queryStage, e)
+    setLogicalLinkForNewQueryStage(queryStage, plan)
     queryStage
   }

-  private def reuseQueryStage(existing: QueryStageExec, exchange: Exchange): QueryStageExec = {
+  private def reuseQueryStage(
+      existing: ReusableQueryStageExec,
+      exchange: Exchange): QueryStageExec = {
     val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, exchange)
@@ -707,12 +745,7 @@ case class AdaptiveSparkPlanExec(
    * Notify the listeners of the physical plan change.
    */
   private def onUpdatePlan(executionId: Long, newSubPlans: Seq[SparkPlan]): Unit = {
-    if (isSubquery) {
-      // When executing subqueries, we can't update the query plan in the UI as the
-      // UI doesn't support partial update yet. However, the subquery may have been
-      // optimized into a different plan and we must let the UI know the SQL metrics
-      // of the new plan nodes, so that it can track the valid accumulator updates later
-      // and display SQL metrics correctly.
+    if (!shouldUpdatePlan) {
       val newMetrics = newSubPlans.flatMap { p =>
         p.flatMap(_.metrics.values.map(m => SQLPlanMetric(m.name.get, m.id, m.metricType)))
       }
@@ -814,8 +847,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
   /**
    * The exchange-reuse map shared across the entire query, including sub-queries.
    */
-  val stageCache: TrieMap[SparkPlan, QueryStageExec] =
-    new TrieMap[SparkPlan, QueryStageExec]()
+  val stageCache: TrieMap[SparkPlan, ReusableQueryStageExec] =
+    new TrieMap[SparkPlan, ReusableQueryStageExec]()
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 939d245304b7..633142170e1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
 import org.apache.spark.sql.execution.datasources.V1WriteCommand
 import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -88,12 +89,15 @@ case class InsertAdaptiveSparkPlan(
   //   - The query may need to add exchanges. It's an overkill to run `EnsureRequirements` here, so
   //     we just check `SparkPlan.requiredChildDistribution` and see if it's possible that the
   //     the query needs to add exchanges later.
+  //   - The query contains nested `AdaptiveSparkPlanExec`.
   //   - The query contains sub-query.
   private def shouldApplyAQE(plan: SparkPlan, isSubquery: Boolean): Boolean = {
     conf.getConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) || isSubquery || {
       plan.exists {
         case _: Exchange => true
         case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true
+        case i: InMemoryTableScanExec
+          if i.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => true
         case p => p.expressions.exists(_.exists {
           case _: SubqueryExpression => true
           case _ => false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 0aee6c21f863..b40206f37496 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -28,20 +28,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.vectorized.ColumnarBatch

 /**
- * A query stage is an independent subgraph of the query plan. Query stage materializes its output
- * before proceeding with further operators of the query plan. The data statistics of the
- * materialized output can be used to optimize subsequent query stages.
- *
- * There are 2 kinds of query stages:
- * 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
- *    another job to execute the further operators.
- * 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
- *    broadcasts the array before executing the further operators.
+ * A query stage is an independent subgraph of the query plan. The AQE framework will materialize
+ * its output before proceeding with further operators of the query plan. The data statistics of
+ * the materialized output can be used to optimize the rest of the query plan.
  */
 abstract class QueryStageExec extends LeafExecNode {
@@ -55,18 +51,6 @@ abstract class QueryStageExec extends LeafExecNode {
    */
   val plan: SparkPlan

-  /**
-   * The canonicalized plan before applying query stage optimizer rules.
-   */
-  val _canonicalized: SparkPlan
-
-  /**
-   * Materialize this query stage, to prepare for the execution, like submitting map stages,
-   * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
-   * stage is ready.
-   */
-  def doMaterialize(): Future[Any]
-
   /**
    * Cancel the stage materialization if in progress; otherwise do nothing.
    */
@@ -82,7 +66,7 @@ abstract class QueryStageExec extends LeafExecNode {
     doMaterialize()
   }

-  def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
+  protected def doMaterialize(): Future[Any]

   /**
    * Returns the runtime statistics after stage materialization.
@@ -121,7 +105,6 @@ abstract class QueryStageExec extends LeafExecNode {
   override def supportsColumnar: Boolean = plan.supportsColumnar
   protected override def doExecuteColumnar(): RDD[ColumnarBatch] = plan.executeColumnar()
   override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
-  override def doCanonicalize(): SparkPlan = _canonicalized

   protected override def stringArgs: Iterator[Any] = Iterator.single(id)

@@ -158,6 +141,25 @@ abstract class QueryStageExec extends LeafExecNode {
   }
 }

+/**
+ * There are 2 kinds of reusable query stages:
+ * 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
+ *    another job to execute the further operators.
+ * 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
+ *    broadcasts the array before executing the further operators.
+ */
+abstract class ReusableQueryStageExec extends QueryStageExec {
+
+  /**
+   * The canonicalized plan before applying query stage optimizer rules.
+   */
+  val _canonicalized: SparkPlan
+
+  override def doCanonicalize(): SparkPlan = _canonicalized
+
+  def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
+}
+
 /**
  * A shuffle query stage whose child is a [[ShuffleExchangeLike]] or [[ReusedExchangeExec]].
  *
@@ -168,7 +170,7 @@ abstract class QueryStageExec extends LeafExecNode {
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan,
-    override val _canonicalized: SparkPlan) extends QueryStageExec {
+    override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {

   @transient val shuffle = plan match {
     case s: ShuffleExchangeLike => s
@@ -179,7 +181,7 @@ case class ShuffleQueryStageExec(
   @transient
   private lazy val shuffleFuture = shuffle.submitShuffleJob

-  override def doMaterialize(): Future[Any] = shuffleFuture
+  override protected def doMaterialize(): Future[Any] = shuffleFuture

   override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
     val reuse = ShuffleQueryStageExec(
@@ -219,7 +221,7 @@ case class ShuffleQueryStageExec(
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan,
-    override val _canonicalized: SparkPlan) extends QueryStageExec {
+    override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {

   @transient val broadcast = plan match {
     case b: BroadcastExchangeLike => b
@@ -228,7 +230,7 @@ case class BroadcastQueryStageExec(
       throw new IllegalStateException(s"wrong plan for broadcast stage:\n ${plan.treeString}")
   }

-  override def doMaterialize(): Future[Any] = {
+  override protected def doMaterialize(): Future[Any] = {
     broadcast.submitBroadcastJob
   }

@@ -250,3 +252,44 @@ case class BroadcastQueryStageExec(

   override def getRuntimeStatistics: Statistics = broadcast.runtimeStatistics
 }
+
+/**
+ * A table cache query stage whose child is a [[InMemoryTableScanExec]].
+ *
+ * @param id the query stage id.
+ * @param plan the underlying plan.
+ */
+case class TableCacheQueryStageExec(
+    override val id: Int,
+    override val plan: SparkPlan) extends QueryStageExec {
+
+  @transient val inMemoryTableScan = plan match {
+    case i: InMemoryTableScanExec => i
+    case _ =>
+      throw new IllegalStateException(s"wrong plan for table cache stage:\n ${plan.treeString}")
+  }
+
+  @transient
+  private lazy val future: FutureAction[Unit] = {
+    val rdd = inMemoryTableScan.baseCacheRDD()
+    sparkContext.submitJob(
+      rdd,
+      (_: Iterator[CachedBatch]) => (),
+      (0 until rdd.getNumPartitions).toSeq,
+      (_: Int, _: Unit) => (),
+      ()
+    )
+  }
+
+  override protected def doMaterialize(): Future[Any] = future
+
+  override def isMaterialized: Boolean = super.isMaterialized || inMemoryTableScan.isMaterialized
+
+  override def cancel(): Unit = {
+    if (!isMaterialized) {
+      logDebug(s"Skip canceling the table cache stage: $id")
+    }
+  }
+
+  override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 98f4a164a22c..07f9dfb1d8a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar

 import org.apache.commons.lang3.StringUtils

-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapCol
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{LongAccumulator, Utils}

 /**
@@ -211,6 +211,7 @@ case class CachedRDDBuilder(
   val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
   val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
+  private val materializedPartitions = cachedPlan.session.sparkContext.longAccumulator

   val cachedName = tableName.map(n => s"In-memory table $n")
     .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))
@@ -247,16 +248,12 @@ case class CachedRDDBuilder(
   }

   private def isCachedRDDLoaded: Boolean = {
-    _cachedColumnBuffersAreLoaded || {
-      val bmMaster = SparkEnv.get.blockManager.master
-      val rddLoaded = _cachedColumnBuffers.partitions.forall { partition =>
-        bmMaster.getBlockStatus(RDDBlockId(_cachedColumnBuffers.id, partition.index), false)
-          .exists { case(_, blockStatus) => blockStatus.isCached }
-      }
-      if (rddLoaded) {
-        _cachedColumnBuffersAreLoaded = rddLoaded
-      }
-      rddLoaded
+    _cachedColumnBuffersAreLoaded || {
+      val rddLoaded = _cachedColumnBuffers.partitions.length == materializedPartitions.value
+      if (rddLoaded) {
+        _cachedColumnBuffersAreLoaded = rddLoaded
+      }
+      rddLoaded
     }
   }

@@ -275,10 +272,19 @@ case class CachedRDDBuilder(
         storageLevel,
         cachedPlan.conf)
     }
-    val cached = cb.map { batch =>
-      sizeInBytesStats.add(batch.sizeInBytes)
-      rowCountStats.add(batch.numRows)
-      batch
+    val cached = cb.mapPartitionsInternal { it =>
+      TaskContext.get().addTaskCompletionListener[Unit](_ => {
+        materializedPartitions.add(1L)
+      })
+      new Iterator[CachedBatch] {
+        override def hasNext: Boolean = it.hasNext
+        override def next(): CachedBatch = {
+          val batch = it.next()
+          sizeInBytesStats.add(batch.sizeInBytes)
+          rowCountStats.add(batch.numRows)
+          batch
+        }
+      }
     }.persist(storageLevel)
     cached.setName(cachedName)
     cached
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 0f00a6a3559b..08244a4f84fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -166,4 +166,14 @@ case class InMemoryTableScanExec(
   protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     columnarInputRDD
   }
+
+  def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded
+
+  /**
+   * This method is only used by AQE, which executes the underlying cached RDD directly, without
+   * filtering or row/columnar serialization.
+   */
+  def baseCacheRDD(): RDD[CachedBatch] = {
+    relation.cacheBuilder.cachedColumnBuffers
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 6163e26e49cd..1504207d39cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -106,6 +106,9 @@ object SparkPlanGraph {
           buildSparkPlanGraphNode(
             planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
         }
+      case "TableCacheQueryStage" =>
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
         buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 988695e2667b..d2fe588c9a5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -59,6 +60,7 @@ class AdaptiveQueryExecSuite

   private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
     var finalPlanCnt = 0
+    var hasMetricsEvent = false
     val listener = new SparkListener {
       override def onOtherEvent(event: SparkListenerEvent): Unit = {
         event match {
@@ -67,6 +69,8 @@ class AdaptiveQueryExecSuite
               "AdaptiveSparkPlan isFinalPlan=true")) {
               finalPlanCnt += 1
             }
+          case _: SparkListenerSQLAdaptiveSQLMetricUpdates =>
+            hasMetricsEvent = true
           case _ => // ignore other events
         }
       }
@@ -92,6 +96,10 @@ class AdaptiveQueryExecSuite
     assert(finalPlanCnt == expectedFinalPlanCnt)
     spark.sparkContext.removeSparkListener(listener)

+    val expectedMetrics = findInMemoryTable(planAfter).nonEmpty ||
+      subqueriesAll(planAfter).nonEmpty
+    assert(hasMetricsEvent == expectedMetrics)
+
     val exchanges = adaptivePlan.collect {
       case e: Exchange => e
     }
@@ -160,6 +168,13 @@ class AdaptiveQueryExecSuite
     }
   }

+  private def findInMemoryTable(plan: SparkPlan): Seq[InMemoryTableScanExec] = {
+    collect(plan) {
+      case c: InMemoryTableScanExec
+        if c.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => c
+    }
+  }
+
   private def checkNumLocalShuffleReads(
       plan: SparkPlan, numShufflesWithoutLocalRead: Int = 0): Unit = {
     val numShuffles = collect(plan) {
@@ -2700,6 +2715,56 @@ class AdaptiveQueryExecSuite
       assert(df.rdd.getNumPartitions == 3)
     }
   }
+
+  test("SPARK-42101: Apply AQE if contains nested AdaptiveSparkPlanExec") {
+    withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true") {
+      val df = spark.range(3).repartition().cache()
+      assert(df.sortWithinPartitions("id")
+        .queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
+    }
+  }
+
+  test("SPARK-42101: Make AQE support InMemoryTableScanExec") {
+    withSQLConf(
+      SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(10).selectExpr("cast(id as string) c1")
+      val df2 = spark.range(10).selectExpr("cast(id as string) c2")
+      val cached = df1.join(df2, $"c1" === $"c2").cache()
+
+      def checkShuffleAndSort(firstAccess: Boolean): Unit = {
+        val df = cached.groupBy("c1").agg(max($"c2"))
+        val initialExecutedPlan = df.queryExecution.executedPlan
+        assert(collect(initialExecutedPlan) {
+          case s: ShuffleExchangeLike => s
+        }.size == (if (firstAccess) 1 else 0))
+        assert(collect(initialExecutedPlan) {
+          case s: SortExec => s
+        }.size == (if (firstAccess) 2 else 0))
+        assert(collect(initialExecutedPlan) {
+          case i: InMemoryTableScanExec => i
+        }.head.isMaterialized != firstAccess)
+
+        df.collect()
+        val finalExecutedPlan = df.queryExecution.executedPlan
+        assert(collect(finalExecutedPlan) {
+          case s: ShuffleExchangeLike => s
+        }.isEmpty)
+        assert(collect(finalExecutedPlan) {
+          case s: SortExec => s
+        }.isEmpty)
+        assert(collect(initialExecutedPlan) {
+          case i: InMemoryTableScanExec => i
+        }.head.isMaterialized)
+      }
+
+      // first access to the cache
+      checkShuffleAndSort(firstAccess = true)
+
+      // access the already materialized cache
+      checkShuffleAndSort(firstAccess = false)
+    }
+  }
 }

 /**