Skip to content

Commit bc4fe93

Browse files
committed
Move PullupCorrelatedPredicates and RewritePredicateSubquery after OptimizeSubqueries
This commit moves two rules right next to the rule OptimizeSubqueries. 1. PullupCorrelatedPredicates: the rewrite of [Not] Exists and [Not] In (ListQuery) to PredicateSubquery 2. RewritePredicateSubquery: the rewrite of PredicateSubquery to LeftSemi/LeftAnti With this change, [Not] Exists/In subquery is now rewritten to LeftSemi/LeftAnti at the beginning of Optimizer. By moving rule PullupCorrelatedPredicates after rule OptimizerSubqueries, all the rules from the nested call to the entire Optimizer on the plans in subqueries will need to deal with (1). the correlated columns wrapped with OuterReference, and (2) the SubqueryExpression. We will block any push down of both types of expressions for the following reasons: 1. We do not want to push any correlated expressions further down the plan tree. Deep correlation is not yet supported in Spark, and, even when supported, deep correlation is more difficult to be unnested to a join. 2. We do not want to push any correlated subquery down because the correlated columns' ExprIds in the subquery may need to remap to different ExprIds from the plan below the current Filter that hosts the subquery. Another side effect is we used to push down Exists/In subquery as if it is a predicate in rule PushDownPredicate and rule PushPredicateThroughJoin. Now Exists/In subquery is rewritten to LeftSemi/LeftAnti, we need to handle the push down of LeftSemi/LeftAnti instead. This will be done in a followup commit. Another Todo is to merge the two-stage rewrite in rule PullupCorrelatedPredicates and rule RewritePredicateSubquery into a single stage rewrite.
1 parent 9e1c18c commit bc4fe93

File tree

3 files changed

+40
-33
lines changed

3 files changed

+40
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
6868
// since the other rules might make two separate Unions operators adjacent.
6969
Batch("Union", Once,
7070
CombineUnions) ::
71-
Batch("Pullup Correlated Expressions", Once,
72-
PullupCorrelatedPredicates) ::
7371
Batch("Subquery", Once,
74-
OptimizeSubqueries) ::
72+
OptimizeSubqueries,
73+
PullupCorrelatedPredicates,
74+
RewritePredicateSubquery) ::
7575
Batch("Replace Operators", fixedPoint,
7676
ReplaceIntersectWithSemiJoin,
7777
ReplaceExceptWithAntiJoin,
@@ -130,10 +130,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
130130
ConvertToLocalRelation,
131131
PropagateEmptyRelation) ::
132132
Batch("OptimizeCodegen", Once,
133-
OptimizeCodegen(conf)) ::
134-
Batch("RewriteSubquery", Once,
135-
RewritePredicateSubquery,
136-
CollapseProject) :: Nil
133+
OptimizeCodegen(conf)) :: Nil
137134
}
138135

139136
/**
@@ -746,7 +743,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
746743
// state and all the input rows processed before. In another word, the order of input rows
747744
// matters for non-deterministic expressions, while pushing down predicates changes the order.
748745
case filter @ Filter(condition, project @ Project(fields, grandChild))
749-
if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
746+
if fields.forall(_.deterministic) &&
747+
!SubqueryExpression.hasCorrelatedSubquery(condition) &&
748+
!SubExprUtils.containsOuter(condition) =>
750749

751750
// Create a map of Aliases to their values from the child projection.
752751
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
@@ -769,7 +768,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
769768
splitConjunctivePredicates(condition).span(_.deterministic)
770769

771770
val (pushDown, rest) = candidates.partition { cond =>
772-
cond.references.subsetOf(partitionAttrs)
771+
cond.references.subsetOf(partitionAttrs) &&
772+
!SubqueryExpression.hasCorrelatedSubquery(cond) &&
773+
!SubExprUtils.containsOuter(cond)
773774
}
774775

775776
val stayUp = rest ++ containingNonDeterministic
@@ -797,7 +798,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
797798

798799
val (pushDown, rest) = candidates.partition { cond =>
799800
val replaced = replaceAlias(cond, aliasMap)
800-
cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet)
801+
cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) &&
802+
!SubqueryExpression.hasCorrelatedSubquery(cond) &&
803+
!SubExprUtils.containsOuter(cond)
801804
}
802805

803806
val stayUp = rest ++ containingNonDeterministic
@@ -815,7 +818,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
815818

816819
case filter @ Filter(condition, union: Union) =>
817820
// Union could change the rows, so non-deterministic predicate can't be pushed down
818-
val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic)
821+
val (candidates, containingNonDeterministic) =
822+
splitConjunctivePredicates(condition).span(_.deterministic)
823+
824+
val (pushDown, rest) = candidates.partition { cond =>
825+
!SubqueryExpression.hasCorrelatedSubquery(cond) &&
826+
!SubExprUtils.containsOuter(cond)
827+
}
828+
val stayUp = rest ++ containingNonDeterministic
819829

820830
if (pushDown.nonEmpty) {
821831
val pushDownCond = pushDown.reduceLeft(And)
@@ -839,7 +849,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
839849
}
840850

841851
case filter @ Filter(condition, u: UnaryNode)
842-
if canPushThrough(u) && u.expressions.forall(_.deterministic) =>
852+
if canPushThrough(u) && u.expressions.forall(_.deterministic) &&
853+
!SubqueryExpression.hasCorrelatedSubquery(condition) &&
854+
!SubExprUtils.containsOuter(condition) =>
843855
pushDownPredicate(filter, u.child) { predicate =>
844856
u.withNewChildren(Seq(Filter(predicate, u.child)))
845857
}
@@ -887,20 +899,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
887899
filter
888900
}
889901
}
890-
891-
/**
892-
* Check if we can safely push a filter through a projection, by making sure that predicate
893-
* subqueries in the condition do not contain the same attributes as the plan they are moved
894-
* into. This can happen when the plan and predicate subquery have the same source.
895-
*/
896-
private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = {
897-
val attributes = plan.outputSet
898-
val matched = condition.find {
899-
case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty
900-
case _ => false
901-
}
902-
matched.isEmpty
903-
}
904902
}
905903

906904
/**
@@ -927,13 +925,18 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
927925
// any deterministic expression that follows a non-deterministic expression. To achieve this,
928926
// we only consider pushing down those expressions that precede the first non-deterministic
929927
// expression in the condition.
930-
val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic)
928+
val (candidates, containingNonDeterministic) = condition.span(_.deterministic)
929+
val (pushDownCandidates, subquery) = candidates.partition { cond =>
930+
!SubqueryExpression.hasCorrelatedSubquery(cond) &&
931+
!SubExprUtils.containsOuter(cond)
932+
}
931933
val (leftEvaluateCondition, rest) =
932934
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
933935
val (rightEvaluateCondition, commonCondition) =
934936
rest.partition(expr => expr.references.subsetOf(right.outputSet))
935937

936-
(leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic)
938+
(leftEvaluateCondition, rightEvaluateCondition,
939+
subquery ++ commonCondition ++ containingNonDeterministic)
937940
}
938941

939942
def apply(plan: LogicalPlan): LogicalPlan = plan transform {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,11 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
445445
* Returns whether the expression returns null or false when all inputs are nulls.
446446
*/
447447
private def canFilterOutNull(e: Expression): Boolean = {
448-
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
448+
if (!e.deterministic ||
449+
SubqueryExpression.hasCorrelatedSubquery(e) ||
450+
SubExprUtils.containsOuter(e)) {
451+
return false
452+
}
449453
val attributes = e.references.toSeq
450454
val emptyRow = new GenericInternalRow(attributes.length)
451455
val boundE = BindReferences.bindReference(e, attributes)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ class FilterPushdownSuite extends PlanTest {
799799
comparePlans(optimized, correctedAnswer)
800800
}
801801

802-
test("predicate subquery: push down simple") {
802+
test("correlated subquery (simple): no push down") {
803803
val x = testRelation.subquery('x)
804804
val y = testRelation.subquery('y)
805805
val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z)
@@ -809,14 +809,14 @@ class FilterPushdownSuite extends PlanTest {
809809
.where(Exists(z.where("x.a".attr === "z.a".attr)))
810810
.analyze
811811
val answer = x
812-
.where(Exists(z.where("x.a".attr === "z.a".attr)))
813812
.join(y, Inner, Option("x.a".attr === "y.a".attr))
813+
.where(Exists(z.where("x.a".attr === "z.a".attr)))
814814
.analyze
815815
val optimized = Optimize.execute(Optimize.execute(query))
816816
comparePlans(optimized, answer)
817817
}
818818

819-
test("predicate subquery: push down complex") {
819+
test("correlated subquery (complex): no push down") {
820820
val w = testRelation.subquery('w)
821821
val x = testRelation.subquery('x)
822822
val y = testRelation.subquery('y)
@@ -828,9 +828,9 @@ class FilterPushdownSuite extends PlanTest {
828828
.where(Exists(z.where("w.a".attr === "z.a".attr)))
829829
.analyze
830830
val answer = w
831-
.where(Exists(z.where("w.a".attr === "z.a".attr)))
832831
.join(x, Inner, Option("w.a".attr === "x.a".attr))
833832
.join(y, LeftOuter, Option("x.a".attr === "y.a".attr))
833+
.where(Exists(z.where("w.a".attr === "z.a".attr)))
834834
.analyze
835835
val optimized = Optimize.execute(Optimize.execute(query))
836836
comparePlans(optimized, answer)

0 commit comments

Comments
 (0)