diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 7c585aa3da4b8..69998fdf7e6e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable
 import scala.concurrent.ExecutionContext
 import scala.util.control.NonFatal
 
+import org.apache.spark.SparkException
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
@@ -513,14 +514,13 @@ case class AdaptiveSparkPlanExec(
           val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
           // Create a query stage only when all the child query stages are ready.
           if (result.allChildStagesMaterialized) {
-            var newStage = newQueryStage(newPlan)
-            assert(newStage.isInstanceOf[ReusableQueryStageExec])
+            var newStage = newQueryStage(newPlan).asInstanceOf[ExchangeQueryStageExec]
             if (conf.exchangeReuseEnabled) {
               // Check the `stageCache` again for reuse. If a match is found, ditch the new stage
               // and reuse the existing stage found in the `stageCache`, otherwise update the
               // `stageCache` with the new stage.
               val queryStage = context.stageCache.getOrElseUpdate(
-                newStage.plan.canonicalized, newStage.asInstanceOf[ReusableQueryStageExec])
+                newStage.plan.canonicalized, newStage)
               if (queryStage.ne(newStage)) {
                 newStage = reuseQueryStage(queryStage, e)
               }
@@ -561,40 +561,34 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def newQueryStage(plan: SparkPlan): QueryStageExec = {
+    val optimizedPlan = plan match {
+      case e: Exchange =>
+        e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
+      case _ => plan
+    }
+    val newPlan = applyPhysicalRules(
+      optimizedPlan,
+      postStageCreationRules(outputsColumnar = plan.supportsColumnar),
+      Some((planChangeLogger, "AQE Post Stage Creation")))
     val queryStage = plan match {
       case s: ShuffleExchangeLike =>
-        val optimizedPlan = optimizeQueryStage(s.child, isFinalStage = false)
-        val newShuffle = applyPhysicalRules(
-          s.withNewChildren(Seq(optimizedPlan)),
-          postStageCreationRules(outputsColumnar = s.supportsColumnar),
-          Some((planChangeLogger, "AQE Post Stage Creation")))
-        if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
-          throw new IllegalStateException(
+        if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
+          throw SparkException.internalError(
            "Custom columnar rules cannot transform shuffle node to something else.")
         }
-        ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
+        ShuffleQueryStageExec(currentStageId, newPlan, s.canonicalized)
       case b: BroadcastExchangeLike =>
-        val optimizedPlan = optimizeQueryStage(b.child, isFinalStage = false)
-        val newBroadcast = applyPhysicalRules(
-          b.withNewChildren(Seq(optimizedPlan)),
-          postStageCreationRules(outputsColumnar = b.supportsColumnar),
-          Some((planChangeLogger, "AQE Post Stage Creation")))
-        if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
-          throw new IllegalStateException(
+        if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
+          throw SparkException.internalError(
            "Custom columnar rules cannot transform broadcast node to something else.")
         }
-        BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
+        BroadcastQueryStageExec(currentStageId, newPlan, b.canonicalized)
       case i: InMemoryTableScanExec =>
-        val newInMemoryTableScan = applyPhysicalRules(
-          i,
-          postStageCreationRules(outputsColumnar = i.supportsColumnar),
-          Some((planChangeLogger, "AQE Post Stage Creation")))
-        if (!newInMemoryTableScan.isInstanceOf[InMemoryTableScanExec]) {
-          throw new IllegalStateException("Custom columnar rules cannot transform " +
+        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+          throw SparkException.internalError("Custom columnar rules cannot transform " +
            "`InMemoryTableScanExec` node to something else.")
         }
-        TableCacheQueryStageExec(
-          currentStageId, newInMemoryTableScan.asInstanceOf[InMemoryTableScanExec])
+        TableCacheQueryStageExec(currentStageId, newPlan.asInstanceOf[InMemoryTableScanExec])
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, plan)
@@ -602,8 +596,8 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def reuseQueryStage(
-      existing: ReusableQueryStageExec,
-      exchange: Exchange): QueryStageExec = {
+      existing: ExchangeQueryStageExec,
+      exchange: Exchange): ExchangeQueryStageExec = {
     val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, exchange)
@@ -770,7 +764,7 @@ case class AdaptiveSparkPlanExec(
     currentPhysicalPlan.foreach {
       // earlyFailedStage is the stage which failed before calling doMaterialize,
       // so we should avoid calling cancel on it to re-trigger the failure again.
-      case s: QueryStageExec if !earlyFailedStage.contains(s.id) =>
+      case s: ExchangeQueryStageExec if !earlyFailedStage.contains(s.id) =>
         try {
           s.cancel()
         } catch {
@@ -847,8 +841,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
   /**
    * The exchange-reuse map shared across the entire query, including sub-queries.
    */
-  val stageCache: TrieMap[SparkPlan, ReusableQueryStageExec] =
-    new TrieMap[SparkPlan, ReusableQueryStageExec]()
+  val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
+    new TrieMap[SparkPlan, ExchangeQueryStageExec]()
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index b40206f37496e..72e7fc937f282 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -51,11 +51,6 @@ abstract class QueryStageExec extends LeafExecNode {
    */
   val plan: SparkPlan
 
-  /**
-   * Cancel the stage materialization if in progress; otherwise do nothing.
-   */
-  def cancel(): Unit
-
   /**
    * Materialize this query stage, to prepare for the execution, like submitting map stages,
    * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
@@ -142,13 +137,18 @@ abstract class QueryStageExec extends LeafExecNode {
 }
 
 /**
- * There are 2 kinds of reusable query stages:
+ * There are 2 kinds of exchange query stages:
  * 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
  * another job to execute the further operators.
  * 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
  * broadcasts the array before executing the further operators.
  */
-abstract class ReusableQueryStageExec extends QueryStageExec {
+abstract class ExchangeQueryStageExec extends QueryStageExec {
+
+  /**
+   * Cancel the stage materialization if in progress; otherwise do nothing.
+   */
+  def cancel(): Unit
 
   /**
    * The canonicalized plan before applying query stage optimizer rules.
@@ -157,7 +157,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {
 
   override def doCanonicalize(): SparkPlan = _canonicalized
 
-  def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
+  def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec
 }
 
 /**
@@ -170,7 +170,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {
 case class ShuffleQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan,
-    override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
+    override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {
 
   @transient val shuffle = plan match {
     case s: ShuffleExchangeLike => s
@@ -183,7 +183,8 @@ case class ShuffleQueryStageExec(
 
   override protected def doMaterialize(): Future[Any] = shuffleFuture
 
-  override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
+  override def newReuseInstance(
+      newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
     val reuse = ShuffleQueryStageExec(
       newStageId,
       ReusedExchangeExec(newOutput, shuffle),
@@ -221,7 +222,7 @@ case class ShuffleQueryStageExec(
 case class BroadcastQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan,
-    override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
+    override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {
 
   @transient val broadcast = plan match {
     case b: BroadcastExchangeLike => b
@@ -234,7 +235,8 @@ case class BroadcastQueryStageExec(
     broadcast.submitBroadcastJob
   }
 
-  override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
+  override def newReuseInstance(
+      newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
     val reuse = BroadcastQueryStageExec(
       newStageId,
       ReusedExchangeExec(newOutput, broadcast),
@@ -285,11 +287,5 @@ case class TableCacheQueryStageExec(
 
   override def isMaterialized: Boolean = super.isMaterialized || inMemoryTableScan.isMaterialized
 
-  override def cancel(): Unit = {
-    if (!isMaterialized) {
-      logDebug(s"Skip canceling the table cache stage: $id")
-    }
-  }
-
   override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
 }
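Editor's note: the `newQueryStage` change above hoists the shared `optimizeQueryStage` and `applyPhysicalRules` calls out of the three per-stage branches, so each branch is left only verifying that the post-stage-creation rules preserved the node type. The following standalone Scala sketch illustrates that shape; all types here are simplified stand-ins invented for illustration, not Spark APIs.

object NewQueryStageSketch extends App {
  // Simplified stand-ins for the Spark plan nodes; illustrative only.
  sealed trait Plan
  case class Shuffle(child: Plan) extends Plan
  case class Broadcast(child: Plan) extends Plan
  case class Scan(name: String) extends Plan

  // Stand-ins for optimizeQueryStage and the post-stage-creation rules.
  def optimize(p: Plan): Plan = p
  def postStageRules(p: Plan): Plan = p

  def newQueryStage(plan: Plan): String = {
    // Shared preparation, done once up front (the patch hoists this out of each branch).
    val prepared = postStageRules(plan match {
      case Shuffle(c) => Shuffle(optimize(c))
      case Broadcast(c) => Broadcast(optimize(c))
      case other => other // table-cache scans are not re-optimized here
    })
    // Each branch now only verifies that the rules preserved the node type.
    plan match {
      case _: Shuffle =>
        require(prepared.isInstanceOf[Shuffle], "rules must not change a shuffle node")
        s"ShuffleQueryStage($prepared)"
      case _: Broadcast =>
        require(prepared.isInstanceOf[Broadcast], "rules must not change a broadcast node")
        s"BroadcastQueryStage($prepared)"
      case _: Scan =>
        require(prepared.isInstanceOf[Scan], "rules must not change a table scan node")
        s"TableCacheQueryStage($prepared)"
    }
  }

  println(newQueryStage(Shuffle(Scan("t")))) // ShuffleQueryStage(Shuffle(Scan(t)))
}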
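Editor's note: retyping `stageCache` to `TrieMap[SparkPlan, ExchangeQueryStageExec]` and having `newReuseInstance` return `ExchangeQueryStageExec` removes the `asInstanceOf` casts on the reuse path. A minimal, self-contained sketch of that `getOrElseUpdate`-based reuse pattern follows, again with hypothetical stand-in types rather than the real Spark classes.

import scala.collection.concurrent.TrieMap

object StageReuseSketch extends App {
  // Minimal stand-in for ExchangeQueryStageExec: exchange-backed stages are the
  // only reusable (and cancelable) kind, so the cache can be typed to them directly.
  abstract class ExchangeStage(val id: Int, val canonicalized: String) {
    def newReuseInstance(newStageId: Int): ExchangeStage
    def cancel(): Unit
  }

  class ShuffleStage(id: Int, canonicalized: String)
      extends ExchangeStage(id, canonicalized) {
    def newReuseInstance(newStageId: Int): ExchangeStage =
      new ShuffleStage(newStageId, canonicalized)
    def cancel(): Unit = println(s"cancelling shuffle stage $id")
  }

  // Keyed by the canonicalized plan, like AdaptiveExecutionContext.stageCache.
  private val stageCache = new TrieMap[String, ExchangeStage]()
  private var nextStageId = 0

  def createOrReuseStage(canonicalized: String): ExchangeStage = {
    val fresh = new ShuffleStage(nextStageId, canonicalized)
    nextStageId += 1
    val cached = stageCache.getOrElseUpdate(canonicalized, fresh)
    if (cached ne fresh) {
      // A semantically identical exchange was already staged: reuse it under a
      // new stage id. No cast is needed now that the cache holds the subtype.
      val reuse = cached.newReuseInstance(nextStageId)
      nextStageId += 1
      reuse
    } else {
      fresh
    }
  }

  val first = createOrReuseStage("Exchange hashpartitioning(x)")
  val second = createOrReuseStage("Exchange hashpartitioning(x)") // cache hit
  println((first.id, second.id)) // (0,2): second is a reuse instance of the first
}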