Skip to content

Commit 7fccd67

Browse files
committed
[SPARK-42101][SQL][FOLLOWUP] Make QueryStageExec more type safe
### What changes were proposed in this pull request? This is a followup of #39624 . `TableCacheQueryStageExec.cancel` is a noop and we can move `def cancel` out from `QueryStageExec`. Due to this movement, I renamed `ReusableQueryStageExec` to `ExchangeQueryStageExec` ### Why are the changes needed? type safe ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests Closes #40399 from cloud-fan/follow. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 9c7aa90 commit 7fccd67

File tree

2 files changed

+40
-50
lines changed

2 files changed

+40
-50
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import scala.collection.mutable
2626
import scala.concurrent.ExecutionContext
2727
import scala.util.control.NonFatal
2828

29+
import org.apache.spark.SparkException
2930
import org.apache.spark.broadcast
3031
import org.apache.spark.rdd.RDD
3132
import org.apache.spark.sql.SparkSession
@@ -513,14 +514,13 @@ case class AdaptiveSparkPlanExec(
513514
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
514515
// Create a query stage only when all the child query stages are ready.
515516
if (result.allChildStagesMaterialized) {
516-
var newStage = newQueryStage(newPlan)
517-
assert(newStage.isInstanceOf[ReusableQueryStageExec])
517+
var newStage = newQueryStage(newPlan).asInstanceOf[ExchangeQueryStageExec]
518518
if (conf.exchangeReuseEnabled) {
519519
// Check the `stageCache` again for reuse. If a match is found, ditch the new stage
520520
// and reuse the existing stage found in the `stageCache`, otherwise update the
521521
// `stageCache` with the new stage.
522522
val queryStage = context.stageCache.getOrElseUpdate(
523-
newStage.plan.canonicalized, newStage.asInstanceOf[ReusableQueryStageExec])
523+
newStage.plan.canonicalized, newStage)
524524
if (queryStage.ne(newStage)) {
525525
newStage = reuseQueryStage(queryStage, e)
526526
}
@@ -561,49 +561,43 @@ case class AdaptiveSparkPlanExec(
561561
}
562562

563563
private def newQueryStage(plan: SparkPlan): QueryStageExec = {
564+
val optimizedPlan = plan match {
565+
case e: Exchange =>
566+
e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
567+
case _ => plan
568+
}
569+
val newPlan = applyPhysicalRules(
570+
optimizedPlan,
571+
postStageCreationRules(outputsColumnar = plan.supportsColumnar),
572+
Some((planChangeLogger, "AQE Post Stage Creation")))
564573
val queryStage = plan match {
565574
case s: ShuffleExchangeLike =>
566-
val optimizedPlan = optimizeQueryStage(s.child, isFinalStage = false)
567-
val newShuffle = applyPhysicalRules(
568-
s.withNewChildren(Seq(optimizedPlan)),
569-
postStageCreationRules(outputsColumnar = s.supportsColumnar),
570-
Some((planChangeLogger, "AQE Post Stage Creation")))
571-
if (!newShuffle.isInstanceOf[ShuffleExchangeLike]) {
572-
throw new IllegalStateException(
575+
if (!newPlan.isInstanceOf[ShuffleExchangeLike]) {
576+
throw SparkException.internalError(
573577
"Custom columnar rules cannot transform shuffle node to something else.")
574578
}
575-
ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
579+
ShuffleQueryStageExec(currentStageId, newPlan, s.canonicalized)
576580
case b: BroadcastExchangeLike =>
577-
val optimizedPlan = optimizeQueryStage(b.child, isFinalStage = false)
578-
val newBroadcast = applyPhysicalRules(
579-
b.withNewChildren(Seq(optimizedPlan)),
580-
postStageCreationRules(outputsColumnar = b.supportsColumnar),
581-
Some((planChangeLogger, "AQE Post Stage Creation")))
582-
if (!newBroadcast.isInstanceOf[BroadcastExchangeLike]) {
583-
throw new IllegalStateException(
581+
if (!newPlan.isInstanceOf[BroadcastExchangeLike]) {
582+
throw SparkException.internalError(
584583
"Custom columnar rules cannot transform broadcast node to something else.")
585584
}
586-
BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
585+
BroadcastQueryStageExec(currentStageId, newPlan, b.canonicalized)
587586
case i: InMemoryTableScanExec =>
588-
val newInMemoryTableScan = applyPhysicalRules(
589-
i,
590-
postStageCreationRules(outputsColumnar = i.supportsColumnar),
591-
Some((planChangeLogger, "AQE Post Stage Creation")))
592-
if (!newInMemoryTableScan.isInstanceOf[InMemoryTableScanExec]) {
593-
throw new IllegalStateException("Custom columnar rules cannot transform " +
587+
if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
588+
throw SparkException.internalError("Custom columnar rules cannot transform " +
594589
"`InMemoryTableScanExec` node to something else.")
595590
}
596-
TableCacheQueryStageExec(
597-
currentStageId, newInMemoryTableScan.asInstanceOf[InMemoryTableScanExec])
591+
TableCacheQueryStageExec(currentStageId, newPlan.asInstanceOf[InMemoryTableScanExec])
598592
}
599593
currentStageId += 1
600594
setLogicalLinkForNewQueryStage(queryStage, plan)
601595
queryStage
602596
}
603597

604598
private def reuseQueryStage(
605-
existing: ReusableQueryStageExec,
606-
exchange: Exchange): QueryStageExec = {
599+
existing: ExchangeQueryStageExec,
600+
exchange: Exchange): ExchangeQueryStageExec = {
607601
val queryStage = existing.newReuseInstance(currentStageId, exchange.output)
608602
currentStageId += 1
609603
setLogicalLinkForNewQueryStage(queryStage, exchange)
@@ -770,7 +764,7 @@ case class AdaptiveSparkPlanExec(
770764
currentPhysicalPlan.foreach {
771765
// earlyFailedStage is the stage which failed before calling doMaterialize,
772766
// so we should avoid calling cancel on it to re-trigger the failure again.
773-
case s: QueryStageExec if !earlyFailedStage.contains(s.id) =>
767+
case s: ExchangeQueryStageExec if !earlyFailedStage.contains(s.id) =>
774768
try {
775769
s.cancel()
776770
} catch {
@@ -847,8 +841,8 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
847841
/**
848842
* The exchange-reuse map shared across the entire query, including sub-queries.
849843
*/
850-
val stageCache: TrieMap[SparkPlan, ReusableQueryStageExec] =
851-
new TrieMap[SparkPlan, ReusableQueryStageExec]()
844+
val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
845+
new TrieMap[SparkPlan, ExchangeQueryStageExec]()
852846
}
853847

854848
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ abstract class QueryStageExec extends LeafExecNode {
5151
*/
5252
val plan: SparkPlan
5353

54-
/**
55-
* Cancel the stage materialization if in progress; otherwise do nothing.
56-
*/
57-
def cancel(): Unit
58-
5954
/**
6055
* Materialize this query stage, to prepare for the execution, like submitting map stages,
6156
* broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this
@@ -142,13 +137,18 @@ abstract class QueryStageExec extends LeafExecNode {
142137
}
143138

144139
/**
145-
* There are 2 kinds of reusable query stages:
140+
* There are 2 kinds of exchange query stages:
146141
* 1. Shuffle query stage. This stage materializes its output to shuffle files, and Spark launches
147142
* another job to execute the further operators.
148143
* 2. Broadcast query stage. This stage materializes its output to an array in driver JVM. Spark
149144
* broadcasts the array before executing the further operators.
150145
*/
151-
abstract class ReusableQueryStageExec extends QueryStageExec {
146+
abstract class ExchangeQueryStageExec extends QueryStageExec {
147+
148+
/**
149+
* Cancel the stage materialization if in progress; otherwise do nothing.
150+
*/
151+
def cancel(): Unit
152152

153153
/**
154154
* The canonicalized plan before applying query stage optimizer rules.
@@ -157,7 +157,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {
157157

158158
override def doCanonicalize(): SparkPlan = _canonicalized
159159

160-
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec
160+
def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec
161161
}
162162

163163
/**
@@ -170,7 +170,7 @@ abstract class ReusableQueryStageExec extends QueryStageExec {
170170
case class ShuffleQueryStageExec(
171171
override val id: Int,
172172
override val plan: SparkPlan,
173-
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
173+
override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {
174174

175175
@transient val shuffle = plan match {
176176
case s: ShuffleExchangeLike => s
@@ -183,7 +183,8 @@ case class ShuffleQueryStageExec(
183183

184184
override protected def doMaterialize(): Future[Any] = shuffleFuture
185185

186-
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
186+
override def newReuseInstance(
187+
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
187188
val reuse = ShuffleQueryStageExec(
188189
newStageId,
189190
ReusedExchangeExec(newOutput, shuffle),
@@ -221,7 +222,7 @@ case class ShuffleQueryStageExec(
221222
case class BroadcastQueryStageExec(
222223
override val id: Int,
223224
override val plan: SparkPlan,
224-
override val _canonicalized: SparkPlan) extends ReusableQueryStageExec {
225+
override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {
225226

226227
@transient val broadcast = plan match {
227228
case b: BroadcastExchangeLike => b
@@ -234,7 +235,8 @@ case class BroadcastQueryStageExec(
234235
broadcast.submitBroadcastJob
235236
}
236237

237-
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
238+
override def newReuseInstance(
239+
newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
238240
val reuse = BroadcastQueryStageExec(
239241
newStageId,
240242
ReusedExchangeExec(newOutput, broadcast),
@@ -285,11 +287,5 @@ case class TableCacheQueryStageExec(
285287

286288
override def isMaterialized: Boolean = super.isMaterialized || inMemoryTableScan.isMaterialized
287289

288-
override def cancel(): Unit = {
289-
if (!isMaterialized) {
290-
logDebug(s"Skip canceling the table cache stage: $id")
291-
}
292-
}
293-
294290
override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats()
295291
}

0 commit comments

Comments
 (0)