diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 69998fdf7e6e..b3f3b74019ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -537,12 +537,13 @@ case class AdaptiveSparkPlanExec(
       }

     case i: InMemoryTableScanExec =>
+      // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we
+      // hit it the first time, we should always create a new query stage.
       val newStage = newQueryStage(i)
-      val isMaterialized = newStage.isMaterialized
       CreateStageResult(
         newPlan = newStage,
-        allChildStagesMaterialized = isMaterialized,
-        newStages = if (isMaterialized) Seq.empty else Seq(newStage))
+        allChildStagesMaterialized = false,
+        newStages = Seq(newStage))

     case q: QueryStageExec =>
       CreateStageResult(newPlan = q,
@@ -561,34 +562,30 @@
   }

   private def newQueryStage(plan: SparkPlan): QueryStageExec = {
-    val optimizedPlan = plan match {
-      case e: Exchange =>
-        e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
-      case _ => plan
-    }
-    val newPlan = applyPhysicalRules(
-      optimizedPlan,
-      postStageCreationRules(outputsColumnar = plan.supportsColumnar),
-      Some((planChangeLogger, "AQE Post Stage Creation")))
     val queryStage = plan match {
-      case s: ShuffleExchangeLike =>
-        if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
-          throw SparkException.internalError(
-            "Custom columnar rules cannot transform shuffle node to something else.")
-        }
-        ShuffleQueryStageExec(currentStageId, newPlan, s.canonicalized)
-      case b: BroadcastExchangeLike =>
-        if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
-          throw SparkException.internalError(
-            "Custom columnar rules cannot transform broadcast node to something else.")
+      case e: Exchange =>
+        val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
+        val newPlan = applyPhysicalRules(
+          optimized,
+          postStageCreationRules(outputsColumnar = plan.supportsColumnar),
+          Some((planChangeLogger, "AQE Post Stage Creation")))
+        if (e.isInstanceOf[ShuffleExchangeLike]) {
+          if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
+            throw SparkException.internalError(
+              "Custom columnar rules cannot transform shuffle node to something else.")
+          }
+          ShuffleQueryStageExec(currentStageId, newPlan, e.canonicalized)
+        } else {
+          assert(e.isInstanceOf[BroadcastExchangeLike])
+          if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
+            throw SparkException.internalError(
+              "Custom columnar rules cannot transform broadcast node to something else.")
+          }
+          BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
         }
-        BroadcastQueryStageExec(currentStageId, newPlan, b.canonicalized)
       case i: InMemoryTableScanExec =>
-        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
-          throw SparkException.internalError("Custom columnar rules cannot transform " +
-            "`InMemoryTableScanExec` node to something else.")
-        }
-        TableCacheQueryStageExec(currentStageId, newPlan.asInstanceOf[InMemoryTableScanExec])
+        // No need to optimize `InMemoryTableScanExec` as it's a leaf node.
+        TableCacheQueryStageExec(currentStageId, i)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, plan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 633142170e1f..a06352557061 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -96,6 +96,10 @@ case class InsertAdaptiveSparkPlan(
     plan.exists {
       case _: Exchange => true
       case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true
+      // AQE framework has a different way to update the query plan in the UI: it updates the plan
+      // at the end of execution, while non-AQE updates the plan before execution. If the cached
+      // plan is already AQEed, the current plan must be AQEed as well so that the UI can get plan
+      // update correctly.
       case i: InMemoryTableScanExec
           if i.relation.cachedPlan.isInstanceOf[AdaptiveSparkPlanExec] => true
       case p => p.expressions.exists(_.exists {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index a27f783215e1..d48b4fe17517 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -86,7 +86,7 @@ abstract class QueryStageExec extends LeafExecNode {
   protected var _resultOption = new AtomicReference[Option[Any]](None)

   private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption
-  def isMaterialized: Boolean = resultOption.get().isDefined
+  final def isMaterialized: Boolean = resultOption.get().isDefined

   override def output: Seq[Attribute] = plan.output
   override def outputPartitioning: Partitioning = plan.outputPartitioning
@@ -275,20 +275,22 @@ case class TableCacheQueryStageExec(
   }

   @transient
-  private lazy val future: FutureAction[Unit] = {
-    val rdd = inMemoryTableScan.baseCacheRDD()
-    sparkContext.submitJob(
-      rdd,
-      (_: Iterator[CachedBatch]) => (),
-      (0 until rdd.getNumPartitions).toSeq,
-      (_: Int, _: Unit) => (),
-      ()
-    )
+  private lazy val future: Future[Unit] = {
+    if (inMemoryTableScan.isMaterialized) {
+      Future.successful(())
+    } else {
+      val rdd = inMemoryTableScan.baseCacheRDD()
+      sparkContext.submitJob(
+        rdd,
+        (_: Iterator[CachedBatch]) => (),
+        (0 until rdd.getNumPartitions).toSeq,
+        (_: Int, _: Unit) => (),
+        ()
+      )
+    }
   }

   override protected def doMaterialize(): Future[Any] = future

-  override def isMaterialized: Boolean = super.isMaterialized || inMemoryTableScan.isMaterialized
-
   override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
 }
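
Note on the `TableCacheQueryStageExec.future` change above: the new lazy val returns an already-completed future when the in-memory relation has already been materialized, instead of unconditionally submitting a cache-fill job. A minimal sketch of that pattern follows, using hypothetical `CacheLike`/`submitFillJob` stand-ins rather than Spark's actual classes:

    import scala.concurrent.Future

    object MaterializeSketch {
      // Hypothetical stand-in for the cached relation; not a Spark API.
      trait CacheLike {
        def isMaterialized: Boolean
        def submitFillJob(): Future[Unit] // stand-in for sparkContext.submitJob(...)
      }

      // Mirrors the lazy `future` above: if the cache is already materialized,
      // complete immediately; otherwise kick off the cache-fill job once.
      def materialize(cache: CacheLike): Future[Unit] =
        if (cache.isMaterialized) Future.successful(())
        else cache.submitFillJob()
    }

Because the real `future` is a `lazy val`, the already-materialized check runs at most once per stage, which is why the separate `isMaterialized` override removed above is no longer needed.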