@@ -26,6 +26,7 @@ import scala.collection.mutable
2626import scala .concurrent .ExecutionContext
2727import scala .util .control .NonFatal
2828
29+ import org .apache .spark .SparkException
2930import org .apache .spark .broadcast
3031import org .apache .spark .rdd .RDD
3132import org .apache .spark .sql .SparkSession
@@ -513,14 +514,13 @@ case class AdaptiveSparkPlanExec(
513514 val newPlan = e.withNewChildren(Seq (result.newPlan)).asInstanceOf [Exchange ]
514515 // Create a query stage only when all the child query stages are ready.
515516 if (result.allChildStagesMaterialized) {
516- var newStage = newQueryStage(newPlan)
517- assert(newStage.isInstanceOf [ReusableQueryStageExec ])
517+ var newStage = newQueryStage(newPlan).asInstanceOf [ExchangeQueryStageExec ]
518518 if (conf.exchangeReuseEnabled) {
519519 // Check the `stageCache` again for reuse. If a match is found, ditch the new stage
520520 // and reuse the existing stage found in the `stageCache`, otherwise update the
521521 // `stageCache` with the new stage.
522522 val queryStage = context.stageCache.getOrElseUpdate(
523- newStage.plan.canonicalized, newStage. asInstanceOf [ ReusableQueryStageExec ] )
523+ newStage.plan.canonicalized, newStage)
524524 if (queryStage.ne(newStage)) {
525525 newStage = reuseQueryStage(queryStage, e)
526526 }
@@ -561,49 +561,43 @@ case class AdaptiveSparkPlanExec(
561561 }
562562
/**
 * Creates a new [[QueryStageExec]] that wraps `plan`.
 *
 * For an [[Exchange]] node, the child is first optimized with the AQE stage-level
 * optimizer rules (`isFinalStage = false` since this is an intermediate stage); other
 * node types are left untouched. The post-stage-creation physical rules (which may
 * inject columnar transitions) are then applied to the whole candidate plan.
 *
 * Custom columnar rules are required to preserve the node type: a shuffle must stay a
 * [[ShuffleExchangeLike]], a broadcast a [[BroadcastExchangeLike]], and a table-cache
 * scan an [[InMemoryTableScanExec]]; otherwise this is an internal error.
 *
 * Side effects: increments `currentStageId` and links the new stage back to the
 * logical plan via `setLogicalLinkForNewQueryStage`.
 *
 * @param plan the physical plan to wrap; must be a [[ShuffleExchangeLike]],
 *             [[BroadcastExchangeLike]] or [[InMemoryTableScanExec]] (the match is
 *             otherwise non-exhaustive — callers only pass these node types)
 * @return the freshly created query stage
 */
private def newQueryStage(plan: SparkPlan): QueryStageExec = {
  // Optimize the stage body (the exchange's child) before stage creation; the
  // exchange node itself must survive unchanged so it can be matched on below.
  val optimizedPlan = plan match {
    case e: Exchange =>
      e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
    case _ => plan
  }
  val newPlan = applyPhysicalRules(
    optimizedPlan,
    postStageCreationRules(outputsColumnar = plan.supportsColumnar),
    Some((planChangeLogger, "AQE Post Stage Creation")))
  val queryStage = plan match {
    case s: ShuffleExchangeLike =>
      if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
        throw SparkException.internalError(
          "Custom columnar rules cannot transform shuffle node to something else.")
      }
      ShuffleQueryStageExec(currentStageId, newPlan, s.canonicalized)
    case b: BroadcastExchangeLike =>
      if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
        throw SparkException.internalError(
          "Custom columnar rules cannot transform broadcast node to something else.")
      }
      BroadcastQueryStageExec(currentStageId, newPlan, b.canonicalized)
    case i: InMemoryTableScanExec =>
      if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
        throw SparkException.internalError("Custom columnar rules cannot transform " +
          "`InMemoryTableScanExec` node to something else.")
      }
      TableCacheQueryStageExec(currentStageId, newPlan.asInstanceOf[InMemoryTableScanExec])
  }
  currentStageId += 1
  setLogicalLinkForNewQueryStage(queryStage, plan)
  queryStage
}
603597
604598 private def reuseQueryStage (
605- existing : ReusableQueryStageExec ,
606- exchange : Exchange ): QueryStageExec = {
599+ existing : ExchangeQueryStageExec ,
600+ exchange : Exchange ): ExchangeQueryStageExec = {
607601 val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
608602 currentStageId += 1
609603 setLogicalLinkForNewQueryStage(queryStage, exchange)
@@ -770,7 +764,7 @@ case class AdaptiveSparkPlanExec(
770764 currentPhysicalPlan.foreach {
771765 // earlyFailedStage is the stage which failed before calling doMaterialize,
772766 // so we should avoid calling cancel on it to re-trigger the failure again.
773- case s : QueryStageExec if ! earlyFailedStage.contains(s.id) =>
767+ case s : ExchangeQueryStageExec if ! earlyFailedStage.contains(s.id) =>
774768 try {
775769 s.cancel()
776770 } catch {
@@ -847,8 +841,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
847841 /**
848842 * The exchange-reuse map shared across the entire query, including sub-queries.
849843 */
850- val stageCache : TrieMap [SparkPlan , ReusableQueryStageExec ] =
851- new TrieMap [SparkPlan , ReusableQueryStageExec ]()
844+ val stageCache : TrieMap [SparkPlan , ExchangeQueryStageExec ] =
845+ new TrieMap [SparkPlan , ExchangeQueryStageExec ]()
852846}
853847
854848/**
0 commit comments