
Commit 1a35685

maryannxue and cloud-fan committed
[SPARK-37670][SQL] Support predicate pushdown and column pruning for de-duped CTEs
This PR adds predicate push-down and column pruning to CTEs that are not inlined as well as fixes a few potential correctness issues: 1) Replace (previously not inlined) CTE refs with Repartition operations at the end of logical plan optimization so that WithCTE is not carried over to physical plan. As a result, we can simplify the logic of physical planning, as well as avoid a correctness issue where the logical link of a physical plan node can point to `WithCTE` and lead to unexpected behaviors in AQE, e.g., class cast exceptions in DPP. 2) Pull (not inlined) CTE defs from subqueries up to the main query level, in order to avoid creating copies of the same CTE def during predicate push-downs and other transformations. 3) Make CTE IDs more deterministic by starting from 0 for each query. Improve de-duped CTEs' performance with predicate pushdown and column pruning; fixes de-duped CTEs' correctness issues. No. Added UTs. Closes #34929 from maryannxue/cte-followup. Lead-authored-by: Maryann Xue <maryann.xue@gmail.com> Co-authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com> (cherry picked from commit 175e429) Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 35ec300 commit 1a35685
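To make the intent concrete, here is an illustrative sketch, not taken from this PR: the CTE below is non-deterministic and referenced twice, so under the conditions in InlineCTE it stays as a shared, de-duplicated definition, and the new pushdown/pruning rule can then act on that shared definition. The object name, session setup, and data are assumptions of the sketch, under default settings.

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: a de-duplicated CTE that benefits from the new pushdown rule.
object DedupedCtePushdownExample extends App {
  val spark = SparkSession.builder().master("local[*]").appName("cte-pushdown-sketch").getOrCreate()

  // `base` is non-deterministic (rand()) and referenced twice, so per InlineCTE's conditions
  // it is kept as a shared CTE definition rather than being inlined. With this patch, the
  // `bucket = 1` filters from both references can be pushed down into that shared definition,
  // and the non-inlined references are replaced at the end of optimization as the commit
  // description above explains.
  val df = spark.sql("""
    WITH base AS (
      SELECT id, id % 10 AS bucket, rand() AS noise
      FROM range(1000000)
    )
    SELECT a.id, b.noise
    FROM base a JOIN base b ON a.id = b.id
    WHERE a.bucket = 1 AND b.bucket = 1
  """)

  df.explain("formatted")  // inspect how the shared CTE definition was optimized
  spark.stop()
}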

22 files changed: +1177 -522 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala

Lines changed: 28 additions & 2 deletions
@@ -66,13 +66,13 @@ object CTESubstitution extends Rule[LogicalPlan] {
       if (cteDefs.isEmpty) {
         substituted
       } else if (substituted eq lastSubstituted.get) {
-        WithCTE(substituted, cteDefs.toSeq)
+        WithCTE(substituted, cteDefs.sortBy(_.id).toSeq)
       } else {
         var done = false
         substituted.resolveOperatorsWithPruning(_ => !done) {
           case p if p eq lastSubstituted.get =>
             done = true
-            WithCTE(p, cteDefs.toSeq)
+            WithCTE(p, cteDefs.sortBy(_.id).toSeq)
         }
       }
     }
@@ -200,6 +200,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
       cteDefs: mutable.ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = {
     val resolvedCTERelations = new mutable.ArrayBuffer[(String, CTERelationDef)](relations.size)
     for ((name, relation) <- relations) {
+      val lastCTEDefCount = cteDefs.length
       val innerCTEResolved = if (isLegacy) {
         // In legacy mode, outer CTE relations take precedence. Here we don't resolve the inner
         // `With` nodes, later we will substitute `UnresolvedRelation`s with outer CTE relations.
@@ -208,8 +209,33 @@ object CTESubstitution extends Rule[LogicalPlan] {
       } else {
         // A CTE definition might contain an inner CTE that has a higher priority, so traverse and
         // substitute CTE defined in `relation` first.
+        // NOTE: we must call `traverseAndSubstituteCTE` before `substituteCTE`, as the relations
+        // in the inner CTE have higher priority over the relations in the outer CTE when resolving
+        // inner CTE relations. For example:
+        // WITH t1 AS (SELECT 1)
+        // t2 AS (
+        //   WITH t1 AS (SELECT 2)
+        //   WITH t3 AS (SELECT * FROM t1)
+        // )
+        // t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`.
         traverseAndSubstituteCTE(relation, isCommand, cteDefs)._1
       }
+
+      if (cteDefs.length > lastCTEDefCount) {
+        // We have added more CTE relations to the `cteDefs` from the inner CTE, and these relations
+        // should also be substituted with `resolvedCTERelations` as inner CTE relation can refer to
+        // outer CTE relation. For example:
+        // WITH t1 AS (SELECT 1)
+        // t2 AS (
+        //   WITH t3 AS (SELECT * FROM t1)
+        // )
+        for (i <- lastCTEDefCount until cteDefs.length) {
+          val substituted =
+            substituteCTE(cteDefs(i).child, isLegacy || isCommand, resolvedCTERelations.toSeq)
+          cteDefs(i) = cteDefs(i).copy(child = substituted)
+        }
+      }
+
       // CTE definition can reference a previous one
       val substituted =
         substituteCTE(innerCTEResolved, isLegacy || isCommand, resolvedCTERelations.toSeq)
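The second comment example above can be exercised end to end; this is an illustrative sketch, not part of the patch, assuming a spark-shell session where `spark` is in scope. The inner CTE `t3` refers to the outer CTE `t1`, which is exactly the case the new `lastCTEDefCount` block re-substitutes.

// Illustrative only: an inner CTE (t3) referencing an outer CTE relation (t1).
// The loop added above substitutes t1 into t3's definition after traversing the inner WITH.
spark.sql("""
  WITH t1 AS (SELECT 1 AS c),
       t2 AS (WITH t3 AS (SELECT * FROM t1) SELECT * FROM t3)
  SELECT c FROM t2
""").show()
// Expected: a single row with c = 1, resolved through t2 -> t3 -> t1.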

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 3 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery}
+import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils}
@@ -90,8 +90,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {

   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
-    // of the result of cascading resolution failures.
-    plan.foreachUp {
+    // of the result of cascading resolution failures. Inline all CTEs in the plan to help check
+    // query plan structures in subqueries.
+    val inlineCTE = InlineCTE(alwaysInline = true)
+    inlineCTE(plan).foreachUp {

       case p if p.analyzed => // Skip already analyzed sub-plans

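A sketch of why the up-front inlining matters here, assuming a spark-shell session with `spark` in scope: structural checks on subqueries (for example, that a scalar subquery produces a single column) need to see through a CTE reference to the definition's actual plan, which `InlineCTE(alwaysInline = true)` now guarantees before `foreachUp` runs. The query is illustrative only.

// Illustrative only: the scalar subquery's body is a reference to the CTE `t`; with all CTEs
// inlined up front, CheckAnalysis validates the subquery against the real underlying plan.
spark.sql("""
  WITH t AS (SELECT id, id * 2 AS doubled FROM range(5))
  SELECT id, (SELECT MAX(doubled) FROM t) AS max_doubled
  FROM range(3)
""").show()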

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala

Lines changed: 29 additions & 27 deletions
@@ -28,26 +28,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}

 /**
  * Inlines CTE definitions into corresponding references if either of the conditions satisfies:
- * 1. The CTE definition does not contain any non-deterministic expressions. If this CTE
- *    definition references another CTE definition that has non-deterministic expressions, it
- *    is still OK to inline the current CTE definition.
+ * 1. The CTE definition does not contain any non-deterministic expressions or contains attribute
+ *    references to an outer query. If this CTE definition references another CTE definition that
+ *    has non-deterministic expressions, it is still OK to inline the current CTE definition.
  * 2. The CTE definition is only referenced once throughout the main query and all the subqueries.
  *
- * In addition, due to the complexity of correlated subqueries, all CTE references in correlated
- * subqueries are inlined regardless of the conditions above.
+ * CTE definitions that appear in subqueries and are not inlined will be pulled up to the main
+ * query level.
+ *
+ * @param alwaysInline if true, inline all CTEs in the query plan.
  */
-object InlineCTE extends Rule[LogicalPlan] {
+case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
       val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
       buildCTEMap(plan, cteMap)
-      inlineCTE(plan, cteMap, forceInline = false)
+      val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
+      val inlined = inlineCTE(plan, cteMap, notInlined)
+      // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
+      // WithCTE as top node here.
+      if (notInlined.isEmpty) {
+        inlined
+      } else {
+        WithCTE(inlined, notInlined.toSeq)
+      }
     } else {
       plan
     }
   }

-  private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = {
+  private def shouldInline(cteDef: CTERelationDef, refCount: Int): Boolean = alwaysInline || {
     // We do not need to check enclosed `CTERelationRef`s for `deterministic` or `OuterReference`,
     // because:
     // 1) It is fine to inline a CTE if it references another CTE that is non-deterministic;
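The inline conditions in the updated scaladoc can be observed from SQL; a sketch assuming a spark-shell session with `spark` in scope and default settings (other rules may still reshape the final plan).

// Non-deterministic and referenced twice: expected to stay as a shared (not inlined) CTE
// definition, which is what makes the new pushdown/pruning rule relevant for it.
spark.sql("""
  WITH r AS (SELECT id, rand() AS v FROM range(10))
  SELECT * FROM r a JOIN r b ON a.id = b.id
""").explain()

// Referenced only once: expected to be inlined into the main query, even though it is
// non-deterministic, per condition 2 above.
spark.sql("""
  WITH r AS (SELECT id, rand() AS v FROM range(10))
  SELECT * FROM r WHERE v > 0.5
""").explain()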
@@ -93,25 +104,24 @@
   private def inlineCTE(
       plan: LogicalPlan,
       cteMap: mutable.HashMap[Long, (CTERelationDef, Int)],
-      forceInline: Boolean): LogicalPlan = {
-    val (stripped, notInlined) = plan match {
+      notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
+    plan match {
       case WithCTE(child, cteDefs) =>
-        val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
         cteDefs.foreach { cteDef =>
           val (cte, refCount) = cteMap(cteDef.id)
           if (refCount > 0) {
-            val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, forceInline))
+            val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined))
             cteMap.update(cteDef.id, (inlined, refCount))
-            if (!forceInline && !shouldInline(inlined, refCount)) {
+            if (!shouldInline(inlined, refCount)) {
              notInlined.append(inlined)
            }
          }
        }
-        (inlineCTE(child, cteMap, forceInline), notInlined.toSeq)
+        inlineCTE(child, cteMap, notInlined)

       case ref: CTERelationRef =>
         val (cteDef, refCount) = cteMap(ref.cteId)
-        val newRef = if (forceInline || shouldInline(cteDef, refCount)) {
+        if (shouldInline(cteDef, refCount)) {
          if (ref.outputSet == cteDef.outputSet) {
            cteDef.child
          } else {
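The refactor above replaces the per-case `(plan, notInlined)` pairs with a single buffer threaded through the recursion, so `apply` can wrap the surviving definitions in one `WithCTE` at the top. A minimal, self-contained sketch of that accumulation pattern on toy types (not Catalyst classes, all names hypothetical):

import scala.collection.mutable

object InlineAccumulationSketch {
  // Toy stand-ins for the logical plan; names are illustrative only.
  sealed trait Node
  case class Def(id: Long, body: Node) extends Node
  case class Ref(id: Long) extends Node
  case class With(child: Node, defs: Seq[Def]) extends Node
  case class Branch(children: Seq[Node]) extends Node
  case class Leaf(value: Int) extends Node

  // The recursion threads one `notInlined` buffer instead of returning
  // (rewrittenNode, notInlinedDefs) pairs from every case; the caller can wrap the result
  // in a single With(...) at the top if the buffer ends up non-empty.
  def inline(
      node: Node,
      defs: Map[Long, Def],
      refCount: Map[Long, Int],
      notInlined: mutable.ArrayBuffer[Def]): Node = node match {
    case With(child, localDefs) =>
      localDefs.foreach { d =>
        if (refCount.getOrElse(d.id, 0) > 1) notInlined += d  // keep shared definitions
      }
      inline(child, defs, refCount, notInlined)
    case Ref(id) if refCount.getOrElse(id, 0) <= 1 =>
      inline(defs(id).body, defs, refCount, notInlined)  // single use: splice in the body
    case Branch(children) =>
      Branch(children.map(inline(_, defs, refCount, notInlined)))
    case other =>
      other  // Leaf, or a Ref that stays a shared reference
  }
}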
@@ -125,24 +135,16 @@
         } else {
           ref
         }
-        (newRef, Seq.empty)

       case _ if plan.containsPattern(CTE) =>
-        val newPlan = plan
-          .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, forceInline)))
+        plan
+          .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined)))
          .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
            case e: SubqueryExpression =>
-              e.withNewPlan(inlineCTE(e.plan, cteMap, forceInline = e.isCorrelated))
+              e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
          }
-        (newPlan, Seq.empty)

-      case _ => (plan, Seq.empty)
-    }
-
-    if (notInlined.isEmpty) {
-      stripped
-    } else {
-      WithCTE(stripped, notInlined)
+      case _ => plan
     }
   }
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 36 additions & 25 deletions
@@ -126,7 +126,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
         OptimizeUpdateFields,
         SimplifyExtractValueOps,
         OptimizeCsvJsonExprs,
-        CombineConcats) ++
+        CombineConcats,
+        PushdownPredicatesAndPruneColumnsForCTEDef) ++
       extendedOperatorOptimizationRules

     val operatorOptimizationBatch: Seq[Batch] = {
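A sketch of the column-pruning side of `PushdownPredicatesAndPruneColumnsForCTEDef`, assuming a spark-shell session with `spark` in scope and default settings: when every reference to a shared CTE uses only a subset of its columns, the unused columns need not be computed in the shared definition. The schema below is illustrative.

// Illustrative only: `payload` is not used by either reference, so it can be pruned from the
// shared CTE definition; `noise` keeps the CTE non-deterministic so it is not inlined.
spark.sql("""
  WITH wide AS (
    SELECT id, id % 10 AS bucket, rand() AS noise, repeat('x', 100) AS payload
    FROM range(100000)
  )
  SELECT a.id, b.noise
  FROM wide a JOIN wide b ON a.bucket = b.bucket
""").explain("formatted")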
@@ -145,21 +146,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     }

     val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
-    // Technically some of the rules in Finish Analysis are not optimizer rules and belong more
-    // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
-    // However, because we also use the analyzer to canonicalized queries (for view definition),
-    // we do not eliminate subqueries or compute current time in the analyzer.
-    Batch("Finish Analysis", Once,
-      EliminateResolvedHint,
-      EliminateSubqueryAliases,
-      EliminateView,
-      InlineCTE,
-      ReplaceExpressions,
-      RewriteNonCorrelatedExists,
-      PullOutGroupingExpressions,
-      ComputeCurrentTime,
-      ReplaceCurrentLike(catalogManager),
-      SpecialDatetimeValues) ::
+    Batch("Finish Analysis", Once, FinishAnalysis) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -168,6 +155,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     //   extra operators between two adjacent Union operators.
     // - Call CombineUnions again in Batch("Operator Optimizations"),
     //   since the other rules might make two separate Unions operators adjacent.
+    Batch("Inline CTE", Once,
+      InlineCTE()) ::
     Batch("Union", Once,
       RemoveNoopOperators,
       CombineUnions,
@@ -204,6 +193,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveLiteralFromGroupExpressions,
       RemoveRepetitionFromGroupExpressions) :: Nil ++
     operatorOptimizationBatch) :+
+    Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
     // This batch rewrites plans after the operator optimization and
     // before any batches that depend on stats.
     Batch("Pre CBO Rules", Once, preCBORules: _*) :+
@@ -260,14 +250,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   * (defaultBatches - (excludedRules - nonExcludableRules)).
   */
  def nonExcludableRules: Seq[String] =
-    EliminateDistinct.ruleName ::
-    EliminateResolvedHint.ruleName ::
-    EliminateSubqueryAliases.ruleName ::
-    EliminateView.ruleName ::
-    ReplaceExpressions.ruleName ::
-    ComputeCurrentTime.ruleName ::
-    SpecialDatetimeValues.ruleName ::
-    ReplaceCurrentLike(catalogManager).ruleName ::
+    FinishAnalysis.ruleName ::
    RewriteDistinctAggregates.ruleName ::
    ReplaceDeduplicateWithAggregate.ruleName ::
    ReplaceIntersectWithSemiJoin.ruleName ::
@@ -281,9 +264,37 @@ abstract class Optimizer(catalogManager: CatalogManager)
    RewritePredicateSubquery.ruleName ::
    NormalizeFloatingNumbers.ruleName ::
    ReplaceUpdateFieldsExpression.ruleName ::
-    PullOutGroupingExpressions.ruleName ::
    RewriteLateralSubquery.ruleName :: Nil

+  /**
+   * Apply finish-analysis rules for the entire plan including all subqueries.
+   */
+  object FinishAnalysis extends Rule[LogicalPlan] {
+    // Technically some of the rules in Finish Analysis are not optimizer rules and belong more
+    // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime).
+    // However, because we also use the analyzer to canonicalized queries (for view definition),
+    // we do not eliminate subqueries or compute current time in the analyzer.
+    private val rules = Seq(
+      EliminateResolvedHint,
+      EliminateSubqueryAliases,
+      EliminateView,
+      ReplaceExpressions,
+      RewriteNonCorrelatedExists,
+      PullOutGroupingExpressions,
+      ComputeCurrentTime,
+      ReplaceCurrentLike(catalogManager),
+      SpecialDatetimeValues)
+
+    override def apply(plan: LogicalPlan): LogicalPlan = {
+      rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+        .transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
+          case s: SubqueryExpression =>
+            val Subquery(newPlan, _) = apply(Subquery.fromExpression(s))
+            s.withNewPlan(newPlan)
+        }
+    }
+  }
+
  /**
   * Optimize all the subqueries inside expression.
   */
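The shape of `FinishAnalysis` above, folding a fixed list of rules over the plan once and then recursing into every subquery plan, reduces to a small self-contained sketch on toy types (names hypothetical, not Spark code):

object FinishAnalysisSketch {
  // Toy types only; `subqueries` stands in for plans nested inside SubqueryExpression.
  case class Plan(name: String, subqueries: Seq[Plan] = Nil)

  val rules: Seq[Plan => Plan] = Seq(
    p => p.copy(name = p.name + ":ruleA"),
    p => p.copy(name = p.name + ":ruleB"))

  // Apply every rule once at this level, then recurse so each nested subquery plan
  // goes through the same finish-analysis rules exactly once.
  def finishAnalysis(plan: Plan): Plan = {
    val rewritten = rules.foldLeft(plan) { case (p, rule) => rule(p) }
    rewritten.copy(subqueries = rewritten.subqueries.map(finishAnalysis))
  }

  def main(args: Array[String]): Unit = {
    // Both the main plan and its subquery are rewritten by ruleA and ruleB.
    println(finishAnalysis(Plan("main", Seq(Plan("sub")))))
  }
}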
