From b98865127a39bde885f9b1680cfe608629d59d51 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Fri, 29 Jul 2016 17:43:56 -0400
Subject: [PATCH 01/18] [SPARK-16804][SQL] Correlated subqueries containing
 LIMIT return incorrect results

## What changes were proposed in this pull request?

This patch fixes the incorrect results in the rule ResolveSubquery in
Catalyst's Analysis phase.

## How was this patch tested?

./dev/run-tests and a new unit test on the problematic pattern.

---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++++
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala    |  8 ++++++++
 2 files changed, 18 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2efa997ff22d..c3ee6517875c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1021,6 +1021,16 @@ class Analyzer(
       case e: Expand =>
         failOnOuterReferenceInSubTree(e, "an EXPAND")
         e
+      case l @ LocalLimit(_, child) =>
+        failOnOuterReferenceInSubTree(l, "LIMIT")
+        l
+      // Since LIMIT <n> is represented as GlobalLimit(<n>, (LocalLimit (<n>, child))
+      // and we are walking bottom up, we will fail on LocalLimit before
+      // reaching GlobalLimit.
+      // The code below is just a safety net.
+      case g @ GlobalLimit(_, child) =>
+        failOnOuterReferenceInSubTree(g, "LIMIT")
+        g
       case p =>
         failOnOuterReference(p)
         p
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index ff112c51697a..b78a988eddbb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -533,5 +533,13 @@ class AnalysisErrorSuite extends AnalysisTest {
       Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))),
       LocalRelation(a))
     assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
+
+    val plan4 = Filter(
+      Exists(
+        Limit(1,
+          Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))
+      ),
+      LocalRelation(a))
+    assertAnalysisError(plan4, "Accessing outer query column is not allowed in LIMIT" :: Nil)
   }
 }

From 069ed8f8e5f14dca7a15701945d42fc27fe82f3c Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Fri, 29 Jul 2016 17:50:02 -0400
Subject: [PATCH 02/18] [SPARK-16804][SQL] Correlated subqueries containing
 LIMIT return incorrect results

## What changes were proposed in this pull request?

This patch fixes the incorrect results in the rule ResolveSubquery in
Catalyst's Analysis phase.

## How was this patch tested?

./dev/run-tests and a new unit test on the problematic pattern.
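For context, the problematic pattern can be sketched as follows (an illustrative query only; the tables t1 and t2 are made up here, not taken from the patch):

    // A correlated column reference under LIMIT. Rewriting the EXISTS to a
    // join would pull the predicate t2.c1 = t1.c1 above the LIMIT, changing
    // which rows the LIMIT keeps and therefore the query result, so the
    // analyzer now rejects the pattern instead.
    sql("""
      | select c1 from t1
      | where exists (select 1 from t2 where t2.c1 = t1.c1 LIMIT 1)
      """.stripMargin)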
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c3ee6517875c..357c763f5946 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1022,14 +1022,14 @@ class Analyzer(
         failOnOuterReferenceInSubTree(e, "an EXPAND")
         e
       case l @ LocalLimit(_, child) =>
-        failOnOuterReferenceInSubTree(l, "LIMIT")
+        failOnOuterReferenceInSubTree(l, "a LIMIT")
         l
       // Since LIMIT <n> is represented as GlobalLimit(<n>, (LocalLimit (<n>, child))
       // and we are walking bottom up, we will fail on LocalLimit before
       // reaching GlobalLimit.
       // The code below is just a safety net.
       case g @ GlobalLimit(_, child) =>
-        failOnOuterReferenceInSubTree(g, "LIMIT")
+        failOnOuterReferenceInSubTree(g, "a LIMIT")
         g
       case p =>
         failOnOuterReference(p)

From edca333c081e6d4e53a91b496fba4a3ef4ee89ac Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Fri, 29 Jul 2016 20:28:15 -0400
Subject: [PATCH 03/18] New positive test cases

---
 .../org/apache/spark/sql/SubquerySuite.scala | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index afed342ff8e2..52387b4b72a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -571,4 +571,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
       Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
   }
+
+  test("SPARK-16804: Correlated subqueries containing LIMIT - 1") {
+    withTempView("onerow") {
+      Seq(1).toDF("c1").createOrReplaceTempView("onerow")
+
+      checkAnswer(
+        sql(
+          """
+            | select c1 from onerow t1
+            | where exists (select 1 from onerow t2 where t1.c1=t2.c1)
+            | and   exists (select 1 from onerow LIMIT 1)""".stripMargin),
+        Row(1) :: Nil)
+    }
+  }
+
+  test("SPARK-16804: Correlated subqueries containing LIMIT - 2") {
+    withTempView("onerow") {
+      Seq(1).toDF("c1").createOrReplaceTempView("onerow")
+
+      checkAnswer(
+        sql(
+          """
+            | select c1 from onerow t1
+            | where exists (select 1
+            |               from   (select 1 from onerow t2 LIMIT 1)
+            |               where  t1.c1=t2.c1)""".stripMargin),
+        Row(1) :: Nil)
+    }
+  }
 }

From 64184fdb77c1a305bb2932e82582da28bb4c0e53 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Mon, 1 Aug 2016 09:20:09 -0400
Subject: [PATCH 04/18] Fix unit test case failure

---
 .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index b78a988eddbb..c08de826bd94 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -540,6 +540,6 @@ class AnalysisErrorSuite extends AnalysisTest {
           Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))
       ),
       LocalRelation(a))
-    assertAnalysisError(plan4, "Accessing outer query column is not allowed in LIMIT" :: Nil)
+    assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil)
   }
 }

From 29f82b05c9e40e7934397257c674b260a8e8a996 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Fri, 5 Aug 2016 13:42:01 -0400
Subject: [PATCH 05/18] blocking TABLESAMPLE

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 7 +++++--
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  | 8 ++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 357c763f5946..9d99c4173d4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1021,16 +1021,19 @@ class Analyzer(
       case e: Expand =>
         failOnOuterReferenceInSubTree(e, "an EXPAND")
         e
-      case l @ LocalLimit(_, child) =>
+      case l @ LocalLimit(_, _) =>
         failOnOuterReferenceInSubTree(l, "a LIMIT")
         l
       // Since LIMIT <n> is represented as GlobalLimit(<n>, (LocalLimit (<n>, child))
       // and we are walking bottom up, we will fail on LocalLimit before
       // reaching GlobalLimit.
       // The code below is just a safety net.
-      case g @ GlobalLimit(_, child) =>
+      case g @ GlobalLimit(_, _) =>
         failOnOuterReferenceInSubTree(g, "a LIMIT")
         g
+      case s @ Sample(_, _, _, _, _) =>
+        failOnOuterReferenceInSubTree(s, "a TABLESAMPLE")
+        s
       case p =>
         failOnOuterReference(p)
         p
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index c08de826bd94..0b7d681be511 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -541,5 +541,13 @@ class AnalysisErrorSuite extends AnalysisTest {
       ),
       LocalRelation(a))
     assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil)
+
+    val plan5 = Filter(
+      Exists(
+        Sample(0.0, 0.5, false, 1L,
+          Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b)
+      ),
+      LocalRelation(a))
+    assertAnalysisError(plan5, "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil)
   }
 }

From ac43ab47907a1ccd6d22f920415fbb4de93d4720 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Fri, 5 Aug 2016 17:10:19 -0400
Subject: [PATCH 06/18] Fixing code styling

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9d99c4173d4a..29ede7048a2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1021,17 +1021,17 @@ class Analyzer(
       case e: Expand =>
         failOnOuterReferenceInSubTree(e, "an EXPAND")
         e
-      case l @ LocalLimit(_, _) =>
+      case l : LocalLimit =>
         failOnOuterReferenceInSubTree(l, "a LIMIT")
         l
       // Since LIMIT <n> is represented as GlobalLimit(<n>, (LocalLimit (<n>, child))
       // and we are walking bottom up, we will fail on LocalLimit before
       // reaching GlobalLimit.
       // The code below is just a safety net.
-      case g @ GlobalLimit(_, _) =>
+      case g : GlobalLimit =>
         failOnOuterReferenceInSubTree(g, "a LIMIT")
         g
-      case s @ Sample(_, _, _, _, _) =>
+      case s : Sample =>
         failOnOuterReferenceInSubTree(s, "a TABLESAMPLE")
         s
       case p =>

From 631d396031e8bf627eb1f4872a4d3a17c144536c Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Sun, 7 Aug 2016 14:39:44 -0400
Subject: [PATCH 07/18] Correcting Scala test style

---
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 0b7d681be511..8935d979414a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -548,6 +548,7 @@ class AnalysisErrorSuite extends AnalysisTest {
           Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b)
       ),
       LocalRelation(a))
-    assertAnalysisError(plan5, "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil)
+    assertAnalysisError(plan5,
+      "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil)
   }
 }

From 7eb9b2dbba3633a1958e38e0019e3ce816300514 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Sun, 7 Aug 2016 22:31:09 -0400
Subject: [PATCH 08/18] One (last) attempt to correct the Scala style tests

---
 .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 8935d979414a..6438065fb292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -548,7 +548,7 @@ class AnalysisErrorSuite extends AnalysisTest {
           Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b)
       ),
       LocalRelation(a))
-    assertAnalysisError(plan5, 
+    assertAnalysisError(plan5,
       "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil)
   }
 }

From bc4fe9326e3c33954d223746ec36fb990fb8d994 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Wed, 22 Mar 2017 19:10:17 -0400
Subject: [PATCH 09/18] Move PullupCorrelatedPredicates and
 RewritePredicateSubquery after OptimizeSubqueries

This commit moves two rules right next to the rule OptimizeSubqueries:

1. PullupCorrelatedPredicates: the rewrite of [Not] Exists and [Not] In
   (ListQuery) to PredicateSubquery
2. RewritePredicateSubquery: the rewrite of PredicateSubquery to
   LeftSemi/LeftAnti

With this change, a [Not] Exists/In subquery is now rewritten to
LeftSemi/LeftAnti at the beginning of the Optimizer. By moving the rule
PullupCorrelatedPredicates after the rule OptimizeSubqueries, all the rules
run by the nested call of the entire Optimizer on subquery plans need to
deal with (1) correlated columns wrapped in OuterReference, and (2)
SubqueryExpression. We block any push down of both types of expressions for
the following reasons:

1. We do not want to push any correlated expressions further down the plan
   tree. Deep correlation is not yet supported in Spark and, even when
   supported, is more difficult to unnest into a join.
2. We do not want to push any correlated subquery down because the
   correlated columns' ExprIds in the subquery may need to be remapped to
   different ExprIds from the plan below the current Filter that hosts the
   subquery.

Another side effect: we used to push down an Exists/In subquery as if it
were a predicate, in the rules PushDownPredicate and
PushPredicateThroughJoin. Now that Exists/In subqueries are rewritten to
LeftSemi/LeftAnti, we need to handle the push down of LeftSemi/LeftAnti
instead. This will be done in a follow-up commit.

Another TODO is to merge the two-stage rewrite in rule
PullupCorrelatedPredicates and rule RewritePredicateSubquery into a
single-stage rewrite.

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 59 ++++++++++---------
 .../spark/sql/catalyst/optimizer/joins.scala  |  6 +-
 .../optimizer/FilterPushdownSuite.scala       |  8 +--
 3 files changed, 40 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d7524a57adbc..c08853983f4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -68,10 +68,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     // since the other rules might make two separate Unions operators adjacent.
     Batch("Union", Once,
       CombineUnions) ::
-    Batch("Pullup Correlated Expressions", Once,
-      PullupCorrelatedPredicates) ::
     Batch("Subquery", Once,
-      OptimizeSubqueries) ::
+      OptimizeSubqueries,
+      PullupCorrelatedPredicates,
+      RewritePredicateSubquery) ::
     Batch("Replace Operators", fixedPoint,
       ReplaceIntersectWithSemiJoin,
       ReplaceExceptWithAntiJoin,
@@ -130,10 +130,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       ConvertToLocalRelation,
       PropagateEmptyRelation) ::
     Batch("OptimizeCodegen", Once,
-      OptimizeCodegen(conf)) ::
-    Batch("RewriteSubquery", Once,
-      RewritePredicateSubquery,
-      CollapseProject) :: Nil
+      OptimizeCodegen(conf)) :: Nil
   }

 /**
@@ -746,7 +743,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     // state and all the input rows processed before. In another word, the order of input rows
     // matters for non-deterministic expressions, while pushing down predicates changes the order.
     case filter @ Filter(condition, project @ Project(fields, grandChild))
-      if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
+      if fields.forall(_.deterministic) &&
+        !SubqueryExpression.hasCorrelatedSubquery(condition) &&
+        !SubExprUtils.containsOuter(condition) =>
       // Create a map of Aliases to their values from the child projection.
       // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
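The guard that recurs throughout the hunks below can be read as one reusable predicate. A minimal sketch, assuming only the two helpers the patch itself calls; the name canBePushedDown is invented here for illustration:

    // A condition is safe to move past another operator only if it neither
    // embeds a correlated subquery nor references a column of an outer query.
    private def canBePushedDown(cond: Expression): Boolean =
      !SubqueryExpression.hasCorrelatedSubquery(cond) &&
        !SubExprUtils.containsOuter(cond)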
@@ -769,7 +768,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { splitConjunctivePredicates(condition).span(_.deterministic) val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) + cond.references.subsetOf(partitionAttrs) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } val stayUp = rest ++ containingNonDeterministic @@ -797,7 +798,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) - cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) + cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } val stayUp = rest ++ containingNonDeterministic @@ -815,7 +818,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + val stayUp = rest ++ containingNonDeterministic if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) @@ -839,7 +849,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } case filter @ Filter(condition, u: UnaryNode) - if canPushThrough(u) && u.expressions.forall(_.deterministic) => + if canPushThrough(u) && u.expressions.forall(_.deterministic) && + !SubqueryExpression.hasCorrelatedSubquery(condition) && + !SubExprUtils.containsOuter(condition) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } @@ -887,20 +899,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } } - - /** - * Check if we can safely push a filter through a projection, by making sure that predicate - * subqueries in the condition do not contain the same attributes as the plan they are moved - * into. This can happen when the plan and predicate subquery have the same source. - */ - private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { - val attributes = plan.outputSet - val matched = condition.find { - case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty - case _ => false - } - matched.isEmpty - } } /** @@ -927,13 +925,18 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // any deterministic expression that follows a non-deterministic expression. To achieve this, // we only consider pushing down those expressions that precede the first non-deterministic // expression in the condition. 
- val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) + val (candidates, containingNonDeterministic) = condition.span(_.deterministic) + val (pushDownCandidates, subquery) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, + subquery ++ commonCondition ++ containingNonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 58e4a230f4ef..5a2ff2104d13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -445,7 +445,11 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. */ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + if (!e.deterministic || + SubqueryExpression.hasCorrelatedSubquery(e) || + SubExprUtils.containsOuter(e)) { + return false + } val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 6feea4060f46..7f0b31736a3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -799,7 +799,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctedAnswer) } - test("predicate subquery: push down simple") { + test("correlated subquery (simple): no push down") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) @@ -809,14 +809,14 @@ class FilterPushdownSuite extends PlanTest { .where(Exists(z.where("x.a".attr === "z.a".attr))) .analyze val answer = x - .where(Exists(z.where("x.a".attr === "z.a".attr))) .join(y, Inner, Option("x.a".attr === "y.a".attr)) + .where(Exists(z.where("x.a".attr === "z.a".attr))) .analyze val optimized = Optimize.execute(Optimize.execute(query)) comparePlans(optimized, answer) } - test("predicate subquery: push down complex") { + test("correlated subquery (complex): no push down") { val w = testRelation.subquery('w) val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -828,9 +828,9 @@ class FilterPushdownSuite extends PlanTest { .where(Exists(z.where("w.a".attr === "z.a".attr))) .analyze val answer = w - .where(Exists(z.where("w.a".attr === "z.a".attr))) .join(x, Inner, Option("w.a".attr === "x.a".attr)) .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) + .where(Exists(z.where("w.a".attr === "z.a".attr))) 
      .analyze
    val optimized = Optimize.execute(Optimize.execute(query))
    comparePlans(optimized, answer)
  }

From 208f384f7f6cc85c326940bdc3d8b976a9e5b119 Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Wed, 22 Mar 2017 19:55:05 -0400
Subject: [PATCH 10/18] This commit works on 3 things:

1. Make PushDownPredicate aware of LeftSemi/LeftAnti
2. Add new rule PushLeftSemiLeftAntiThroughJoin
3. Extend EliminateOuterJoin to deal with LeftSemi/LeftAnti

---
 .../sql/catalyst/expressions/subquery.scala   |  11 +
 .../sql/catalyst/optimizer/Optimizer.scala    | 306 +++++++-
 .../spark/sql/catalyst/optimizer/joins.scala  |  27 +
 .../spark/sql/catalyst/plans/joinTypes.scala  |   7 +
 .../optimizer/FilterPushdownSuite.scala       |  19 +-
 .../sql/LeftSemiOrAntiPushdownSuite.scala     | 661 ++++++++++++++++++
 6 files changed, 1019 insertions(+), 12 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 59db28d58afc..3299162334bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -245,6 +245,17 @@ object ScalarSubquery {
       case _ => false
     }.isDefined
   }
+
+  def hasScalarSubquery(e: Expression): Boolean = {
+    e.find {
+      case s: ScalarSubquery => true
+      case _ => false
+    }.isDefined
+  }
+
+  def hasScalarSubquery(e: Seq[Expression]): Boolean = {
+    e.find(hasScalarSubquery(_)).isDefined
+  }
 }

 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c846c58ecc18..ace3159944c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -79,6 +79,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       ReorderJoin(conf),
       EliminateOuterJoin(conf),
       PushPredicateThroughJoin,
+      PushLeftSemiLeftAntiThroughJoin,
       PushDownPredicate,
       LimitPushDown(conf),
       ColumnPruning,
@@ -397,9 +398,10 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
 * Attempts to eliminate the reading of unneeded columns from the query plan.
 *
 * Since adding Project before Filter conflicts with PushPredicatesThroughProject, this rule will
- * remove the Project p2 in the following pattern:
+ * remove the Project p2 in the following patterns:
 *
 *   p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet)
+ *   p1 @ Project(_, j @ Join(p2 @ Project(_, child), _, LeftSemiOrAnti(_), _))
 *
 * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway.
 */
object ColumnPruning extends Rule[LogicalPlan] {
@@ -499,13 +501,16 @@
   }

   /**
-   * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject,
-   * so remove it.
+   * The Project before Filter or LeftSemi/LeftAnti is not necessary
+   * but conflict with PushPredicatesThroughProject, so remove it.
*/ private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) + case p1 @ Project(_, j @ Join(p2 @ Project(_, child), _, LeftSemiOrAnti(_), _)) + if p2.outputSet.subsetOf(child.outputSet) => + p1.copy(child = j.copy(left = child)) } } @@ -765,6 +770,42 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + // Similar to the above Filter over Project + // LeftSemi/LeftAnti over Project + case join @ Join(project @ Project(projectList, grandChild), rightOp, + LeftSemiOrAnti(joinType), joinCond) + if !grandChild.isInstanceOf[LeafNode] && + !ScalarSubquery.hasScalarSubquery(projectList) && + projectList.forall(_.deterministic) => + + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + var projectListAfterUnalias = AttributeSet.empty + val aliasMap = AttributeMap(projectList.collect { + case a: Alias => + projectListAfterUnalias ++= a.child.references + (a.toAttribute, a.child) + }) + + // If nothing to map from Join to the Project below + // stop the push down + val simple = grandChild match { + case Filter(_, l: LeafNode) => true + case _ => false + } + if (joinCond.isDefined && + // detect potential self-join after pushdown + joinCond.get.references.intersect(projectListAfterUnalias).isEmpty && + (aliasMap.nonEmpty || !simple)) { + val cond = if (joinCond.isDefined) { + Option(replaceAlias(joinCond.get, aliasMap)) + } else None + val res = Project(projectList, Join(grandChild, rightOp, joinType, cond)) + res + } else { + join + } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be // pushed beneath must satisfy the following conditions: // 1. All the expressions are part of window partitioning key. The expressions can be compound. 
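The hunk that follows extends the same partitioning-key restriction to a LeftSemi/LeftAnti join sitting on top of a Window. The core safety check is the subset test below, a condensed sketch of the logic in that hunk rather than new behavior:

    // Only conditions fully covered by the window partitioning keys (plus the
    // right side of the semi/anti join) remove whole partitions at a time and
    // therefore cannot change the window aggregate of any surviving row.
    val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet
    val safeToPush = cond.references.subsetOf(partitionAttrs)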
@@ -793,6 +834,35 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Window + // LeftSemi/LeftAnti over Window + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet + + + val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + } else { + (Nil, Nil) + } + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + join + } + case filter @ Filter(condition, aggregate: Aggregate) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression @@ -826,6 +896,54 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Aggregate + // LeftSemi/LeftAnti over Aggregate + case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + // TODO: detect potential self-join after push down??? + var projectListAfterUnalias = AttributeSet.empty + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + projectListAfterUnalias ++= a.child.references + (a.toAttribute, a.child) + }) + + // For each join condition, expand the alias and check if the condition can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + } else { + (Nil, Nil) + } + + val (pushDown, rest) = candidates.partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + if (pushDownPredicate.references.intersect(projectListAfterUnalias).isEmpty) { + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = + Join(aggregate.child, rightOp, joinType, Option(replaced))) + // If there is no more filter to stay up, just return the Aggregate over Join. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". 
+ if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + } + else { + join + } + } else { + join + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down val (candidates, containingNonDeterministic) = @@ -858,6 +976,42 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Union + // LeftSemi/LeftAnti over Union + case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond) => + // Union could change the rows, so non-deterministic predicate can't be pushed down + val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { + splitConjunctivePredicates(joinCond.get).span(_.deterministic) + } else { + (Nil, Nil) + } + val (pushDown, rest) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = union.output + val newGrandChildren = union.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) + Join(grandchild, rightOp, joinType, Option(newCond)) + } + val newUnion = union.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newUnion) + } else { + newUnion + } + } else { + join + } + case filter @ Filter(condition, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) && !SubqueryExpression.hasCorrelatedSubquery(condition) && @@ -865,6 +1019,14 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } + + // Similar to the above Filter over UnaryNode + // LeftSemi/LeftAnti over UnaryNode + case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond) + if canPushThrough(u) => + pushDownJoin(join, u.child) { joinCond => + u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond)))) + } } private def canPushThrough(p: UnaryNode): Boolean = p match { @@ -893,7 +1055,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { splitConjunctivePredicates(filter.condition).span(_.deterministic) val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(grandchild.outputSet) + cond.references.subsetOf(grandchild.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } val stayUp = rest ++ containingNonDeterministic @@ -909,6 +1073,37 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } } + + private def pushDownJoin( + join: Join, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. 
+ val (candidates, containingNonDeterministic) = if (join.condition.isDefined) { + splitConjunctivePredicates(join.condition.get).span(_.deterministic) + } else { + (Nil, Nil) + } + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + join + } + } } /** @@ -1034,6 +1229,109 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Pushes down a subquery, in the form of [[Join LeftSemi/LeftAnti]] operator + * to the left or right side of a join below. + */ +object PushLeftSemiLeftAntiThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + /** + * Define an enumeration to identify whether a Exists/In subquery, + * in the form of a LeftSemi/LeftAnti, can be pushed down to + * the left table or the right table. + */ + object subqueryPushdown extends Enumeration { + val toRightTable, toLeftTable, none = Value + } + + /** + * Determine which side of the join an Exists/In subquery (in the form of + * LeftSemi/LeftAnti join) can be pushed down to. + */ + private def pushTo(child: Join, subquery: LogicalPlan, joinCond: Option[Expression]) = { + val left = child.left + val right = child.right + val joinType = child.joinType + val subqueryOutput = subquery.outputSet + + if (joinCond.nonEmpty) { + /** + * Note: In order to ensure correctness, it's important to not change the relative ordering of + * any deterministic expression that follows a non-deterministic expression. To achieve this, + * we only consider pushing down those expressions that precede the first non-deterministic + * expression in the condition. + */ + val noPushdown = (subqueryPushdown.none, None) + val conditions = splitConjunctivePredicates(joinCond.get) + val (candidates, containingNonDeterministic) = conditions.span(_.deterministic) + lazy val (pushDownCandidates, subquery) = + candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } + lazy val (leftConditions, rest) = + pushDownCandidates.partition(_.references.subsetOf(left.outputSet ++ subqueryOutput)) + lazy val (rightConditions, commonConditions) = + rest.partition(_.references.subsetOf(right.outputSet ++ subqueryOutput)) + + if (containingNonDeterministic.nonEmpty || subquery.nonEmpty) { + noPushdown + } else { + if (rest.isEmpty && leftConditions.nonEmpty) { + // When all the join conditions are only between left table and the subquery + // push the subquery to the left table. + (subqueryPushdown.toLeftTable, leftConditions.reduceLeftOption(And)) + } else if (leftConditions.isEmpty && rightConditions.nonEmpty && commonConditions.isEmpty) { + // When all the join conditions are only between right table and the subquery + // push the subquery to the right table. + (subqueryPushdown.toRightTable, rightConditions.reduceLeftOption(And)) + } else { + noPushdown + } + } + } else { + /** + * When there is no correlated predicate, + * 1) if this is a left outer join, push the subquery down to the left table + * 2) if a right outer join, to the right table, + * 3) if an inner join, push to either side. 
+ */ + val action = joinType match { + case RightOuter => + subqueryPushdown.toRightTable + case _: InnerLike | LeftOuter => + subqueryPushdown.toLeftTable + case _ => + subqueryPushdown.none + } + (action, None) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // push LeftSemi/LeftAnti down into the join below + case j @ Join(child @ Join(left, right, _ : InnerLike | LeftOuter | RightOuter, belowJoinCond), + subquery, LeftSemiOrAnti(joinType), joinCond) => + val belowJoinType = child.joinType + val (action, newJoinCond) = pushTo(child, subquery, joinCond) + + action match { + case subqueryPushdown.toLeftTable + if (belowJoinType == LeftOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the left table + val newLeft = Join(left, subquery, joinType, newJoinCond) + Join(newLeft, right, belowJoinType, belowJoinCond) + case subqueryPushdown.toRightTable + if (belowJoinType == RightOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the right table + val newRight = Join(right, subquery, joinType, newJoinCond) + Join(left, newRight, belowJoinType, belowJoinCond) + case _ => + // Do nothing + j + } + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 130353c6bc90..dc12892c0364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -478,9 +478,36 @@ case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with } } + private def buildNewJoinType(joinCond: Expression, join: Join, subqueryOutPut: AttributeSet): + JoinType = { + val conditions = splitConjunctivePredicates(joinCond) + val leftConditions = conditions.filter(_.references. + subsetOf(join.left.outputSet ++ subqueryOutPut)) + val rightConditions = conditions.filter(_.references. 
+ subsetOf(join.right.outputSet ++ subqueryOutPut)) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + + join.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + case j @ Join(child @ Join(_, _, RightOuter | LeftOuter | FullOuter, _), + subquery, LeftSemiOrAnti(joinType), joinCond) => + if (joinCond.isDefined) { + val newJoinType = buildNewJoinType(joinCond.get, child, subquery.outputSet) + Join(child.copy(joinType = newJoinType), subquery, joinType, joinCond) + } else j } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae..bebc79248e50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -112,3 +112,10 @@ object LeftExistence { case _ => None } } + +object LeftSemiOrAnti { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 3ab7a9577646..e472d922ca20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -38,6 +38,7 @@ class FilterPushdownSuite extends PlanTest { PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, + PushLeftSemiLeftAntiThroughJoin, CollapseProject) :: Nil } @@ -799,7 +800,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctedAnswer) } - test("correlated subquery (simple): no push down") { + test("predicate subquery: push down simple") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) @@ -809,14 +810,15 @@ class FilterPushdownSuite extends PlanTest { .where(Exists(z.where("x.a".attr === "z.a".attr))) .analyze val answer = x - .join(y, Inner, Option("x.a".attr === "y.a".attr)) .where(Exists(z.where("x.a".attr === "z.a".attr))) + .join(y, Inner, Option("x.a".attr === "y.a".attr)) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } - test("correlated subquery (complex): no push down") { + test("predicate subquery: push down complex") { val w = testRelation.subquery('w) val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -828,12 +830,13 @@ class FilterPushdownSuite extends PlanTest { 
.where(Exists(z.where("w.a".attr === "z.a".attr))) .analyze val answer = w + .where(Exists(z.where("w.a".attr === "z.a".attr))) .join(x, Inner, Option("w.a".attr === "x.a".attr)) .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) - .where(Exists(z.where("w.a".attr === "z.a".attr))) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("SPARK-20094: don't push predicate with IN subquery into join condition") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala new file mode 100644 index 000000000000..a1c35ca2a078 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/* + * Writing test cases using combinatorial testing technique + * Dimension 1: (A) Exists or (B) In + * Dimension 2: (A) LeftSemi, (B) LeftAnti, or (C) ExistenceJoin + * Dimension 3: (A) Join over Project, (B) Join over Agg, (C) Join over Window, + * (D) Join over Union, or (E) Join over other UnaryNode + * Dimension 4: (A) join condition is column or (B) expression + * Dimension 5: Subquery is (A) a single table, or (B) more than one table + * Dimension 6: Parent side is (A) a single table, or (B) more than one table + */ +class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + // setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Integer, java.lang.Integer)](_) + + lazy val t1 = Seq( + row(1, 1, 1), + row(1, 2, 2), + row(2, 1, null), + row(3, 1, 2), + row(null, 0, 3), + row(4, null, 2), + row(0, -1, null)).toDF("t1a", "t1b", "t1c") + + lazy val t2 = Seq( + row(1, 1, 1), + row(2, 1, 1), + row(2, 1, null), + row(3, 3, 3), + row(3, 1, 0), + row(null, null, 1), + row(0, 0, -1)).toDF("t2a", "t2b", "t2c") + + lazy val t3 = Seq( + row(1, 1, 1), + row(2, 1, 0), + row(2, 1, null), + row(10, 4, -1), + row(3, 2, 0), + row(-2, 1, -1), + row(null, null, null)).toDF("t3a", "t3b", "t3c") + + lazy val t4 = Seq( + row(1, 1, 2), + row(1, 2, 1), + row(2, 1, null)).toDF("t4a", "t4b", "t4c") + + lazy val t5 = Seq( + row(1, 1, 1), + row(2, null, 0), + row(2, 1, null)).toDF("t5a", "t5b", "t5c") + + protected override def beforeAll(): Unit = { + super.beforeAll() + t1.createOrReplaceTempView("t1") + t2.createOrReplaceTempView("t2") + t3.createOrReplaceTempView("t3") + t4.createOrReplaceTempView("t4") + t5.createOrReplaceTempView("t5") + } + + 
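Each test below follows the same verification idiom: plan1 is the original query and plan2 is a hand-rewritten query that encodes the expected pushdown; the pair must return the same rows and, where the rewrite is expected to fire, optimize to the same plan. A condensed sketch of the two assertions every test repeats:

    checkAnswer(plan1, plan2)              // result equivalence
    comparePlans(
      plan1.queryExecution.optimizedPlan,  // plan actually produced
      plan2.queryExecution.optimizedPlan)  // plan encoding the expected pushdown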
/** + * TC 1.1: 1A-2B-3A-4B-5A-6A + * Expected result: LeftAnti below Project + * Note that the expression T1A+1 is evaluated twice in Join and Project + * + * TC 1.1.1: Comparing to Inner, we do not push down Inner join under Project + * + * SELECT TX.* + * FROM (SELECT T1A+1 T1A1, T1B + * FROM T1 + * WHERE T1A > 2) TX, T2 + * WHERE T2A = T1A1 + */ + test("TC 1.1: LeftSemi/LeftAnti over Project") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2) tx + | where not exists (select 1 + | from t2 + | where t2a = t1a1) + """.stripMargin) + val plan2 = + sql( + """ + | select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2 + | and not exists (select 1 + | from t2 + | where t2a = t1a+1) + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.2: 1B-2A-3B-4B-5B-6A + * Expected result: LeftSemi below Aggregate + */ + test("TC 1.2: LeftSemi/LeftAnti over Aggregate") { + val plan1 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | group by coalesce(t1c, 0)) tx + | where t1c_expr in (select t2b + | from t2, t3 + | where t2a = t3a) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | where coalesce(t1c, 0) in (select t2b + | from t2, t3 + | where t2a = t3a) + | group by coalesce(t1c, 0)) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.3: 1A-2A-3C-4B-5A-6A + * Expected result: LeftSemi below Window + * + * Variations that yield no push down + * + * TC 1.3.1: We do not match T1B1 to the expression T1B+1 in the PARTITION BY clause + * hence no push down. + * + * SELECT * + * FROM (SELECT T1B+1 as T1B1, SUM(T1B * T1A) OVER (PARTITION BY T1B+1) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B1) + * + * TC 1.3.2: With the additional column Exists from the ExistenceJoin that does not exist + * in Window, and we do not add a compensation, the result is + * we don't push down ExistenceJoin under a Window. 
+ * + * SELECT * + * FROM (SELECT T1B, SUM(T1B * T1A) OVER (PARTITION BY T1B) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B) + * OR T1B1 > 1 + */ + test("TC 1.3: LeftSemi/LeftAnti over Window") { + val plan1 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1) tx + | where exists (select 1 + | from t2 + | where t2b = tx.t1b) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1 + | where exists (select 1 + | from t2 + | where t2b = t1.t1b)) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.4: 1B-2B-3D-4A-5B-6B + * Expected result: LeftAnti below Union + */ + test("TC 1.4: LeftSemi/LeftAnti over Union") { + val plan1 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a) ua + | where t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | and t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a + | and t2c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | ) ua + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * TC 1.5: 1B-2B-3E-4B-5A-6B + * Expected result: LeftAnti below Sort + */ + test("TC 1.5: LeftSemi/LeftAnti over other UnaryNode") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | order by t1b) tx + | where tx.t1a1 not in (select t2a + | from t2 + | where t2b < 3 + | and tx.t3c >= 0) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | and t1.t1a+1 not in (select t2a + | from t2 + | where t2b < 3 + | and t3c >= 0) + | order by t1b) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + + /** + * LeftSemi/LeftAnti over join + * + * Dimension 1: (A) LeftSemi or (B) LeftAnti + * Dimension 2: Join below is (A) Inner (B) LeftOuter (C) RightOuter (D) FullOuter, or, + * (E) LeftSemi/LeftAnti + * Dimension 3: Subquery correlated to (A) left table (B) right table, (C) both tables, + * or, (D) no correlated predicate + */ + /** + * TC 2.1: 1A-2A-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.1: LeftSemi over inner join") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2a >= 2) + | select * + | from join + | where t1a in (select t3a from t3 where t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where t1a in (select t3a from t3 where t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2a >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.2: 1A-2B-3A + * Expected result: LeftSemi join below LeftOuter join + */ + test("TC 2.2: 
LeftSemi over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.3: 1B-2B-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.3: LeftAnti over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.4: 1A-2C-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.4: LeftSemi over right outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2c is null + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.5: 1B-2C-3B + * Expected result: LeftAnti join below RightOuter join + * RightOuter does not convert to Inner because NOT IN can return null. + */ + test("TC 2.5: LeftAnti over right outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where t2a not in (select t3a from t3 where t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where t2a not in (select t3a from t3 where t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.6: 1B-2C-3C + * Expected result: No push down + */ + test("TC 2.6: LeftAnti over right outer join with correlated cols on both left and right tbls") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b > t2b) + """. 
+ stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | left anti join + | (select t3a, t3b + | from t3 + | where t3a is not null + | and t3b is not null) t3 + | on t3a = t1a and t3b > t2b + """.stripMargin) + checkAnswer(plan1, plan2) + // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.7: 1B-2D-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.7: LeftAnti over full outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.8: 1A-2D-3B + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.8: LeftSemi over full outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1) + """. + stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.9: 1A-2E-3A + * Expected result: No push down + */ + test("TC 2.9: LeftSemi over left semi join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left semi join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3c is not null) + """. + stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * + | from t1 + | left semi join t2 + | on t1b = t2b and t2c >= 0) + | select * + | from join + | left semi join t3 + | on t3a = t1a and t3c is not null + """.stripMargin) + checkAnswer(plan1, plan2) + // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + plan1.show + } + /** + * TC 2.10: 1A-2A-3C + * Expected result: No push down + */ + test("TC 2.10: LeftSemi over inner join with correlated columns on both left and right tables") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3a = t2a) + """. 
+          stripMargin)
+    val plan2 =
+      sql(
+        """
+          | with join as
+          | (select *
+          | from t1
+          | inner join t2
+          | on t1b = t2b and t2c is null)
+          | select *
+          | from join
+          | left semi join t3
+          | on t3a = t1a and t3a = t2a
+          """.stripMargin)
+    checkAnswer(plan1, plan2)
+    // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
+    plan1.show
+  }
+  /**
+   * TC 2.11: 1B-2C-3D
+   * Expected result: LeftAnti join below RightOuter join
+   */
+  test("TC 2.11: LeftAnti over right outer join with no correlated columns") {
+    val plan1 =
+      sql(
+        """
+          | with join as
+          | (select * from t1 right join t2 on t1b = t2b and t2c >= 2)
+          | select *
+          | from join
+          | where not exists (select 1 from t3 where t3b < -1)
+          """.
+          stripMargin)
+    val plan2 =
+      sql(
+        """
+          | select *
+          | from t1
+          | right outer join
+          | (select *
+          | from t2
+          | where not exists (select 1 from t3 where t3b < -1)) t2
+          | on t1b = t2b and t2c >= 2
+          """.stripMargin)
+    checkAnswer(plan1, plan2)
+    comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
+    plan1.show
+  }
+  /**
+   * TC 2.12: 1B-2D-3D
+   * Expected result: LeftAnti join above FullOuter join (no push down below the FullOuter)
+   */
+  test("TC 2.12: LeftAnti over full outer join with no correlated columns") {
+    val plan1 =
+      sql(
+        """
+          | with join as
+          | (select * from t1 full join t2 on t1b = t2b and t2c >= 0)
+          | select *
+          | from join
+          | where not exists (select 1 from t3 where t3b < -1)
+          | and (t1c = 1 or t1c is null)
+          """.
+          stripMargin)
+    val plan2 =
+      sql(
+        """
+          | with join as
+          | (select * from t1 full join t2 on t1b = t2b and t2c >= 0)
+          | select *
+          | from join
+          | left anti join t3
+          | on t3b < -1
+          | where (t1c = 1 or t1c is null)
+          """.stripMargin)
+    checkAnswer(plan1, plan2)
+    comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan)
+    plan1.show
+  }
+}

From a86f18b99e7b47c4edc2b4ac602405deba1b706c Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Thu, 23 Mar 2017 17:27:17 -0400
Subject: [PATCH 11/18] Add LeftSemi/LeftAnti's constraints

---
 .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +-
 .../sql/catalyst/plans/logical/basicLogicalOperators.scala    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2d8ec2053a4c..2e6aeb6ac18b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -35,7 +35,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
       .union(inferAdditionalConstraints(constraints))
       .union(constructIsNotNullConstraints(constraints))
       .filter(constraint =>
-        constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
+        constraint.references.nonEmpty && /* constraint.references.subsetOf(outputSet) && */
           constraint.deterministic)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 19db42c80895..b76c1d6b37b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -319,7 +319,7 @@ case class Join(
left.constraints .union(right.constraints) .union(splitConjunctivePredicates(condition.get).toSet) - case LeftSemi if condition.isDefined => + case LeftSemi | LeftAnti if condition.isDefined => left.constraints .union(splitConjunctivePredicates(condition.get).toSet) case j: ExistenceJoin => From fe89f35d0f5d488450265f5aac4c0e607460f704 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Fri, 24 Mar 2017 11:59:21 -0400 Subject: [PATCH 12/18] Revert back QueryPlan.scala and fix FilterPushdownSuite --- .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2e6aeb6ac18b..2d8ec2053a4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -35,7 +35,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && /* constraint.references.subsetOf(outputSet) && */ + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && constraint.deterministic) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index e472d922ca20..4a55db9053c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,6 +33,8 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: + Batch("Subquery", Once, + RewritePredicateSubquery) :: Batch("Filter Pushdown", FixedPoint(10), CombineFilters, PushDownPredicate, From f078309bcd14f306253a522e63d1db754425c38e Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sat, 25 Mar 2017 06:27:54 -0400 Subject: [PATCH 13/18] Clean up and add LeftSemi/Anti pushdown on empty joinCond --- .../sql/catalyst/optimizer/Optimizer.scala | 212 ++++++++++-------- 1 file changed, 114 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ace3159944c2..6d5eb351f0a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -743,6 +743,7 @@ case class PruneFilters(conf: CatalystConf) extends Rule[LogicalPlan] with Predi } } + /** * Pushes [[Filter]] operators through many operators iff: * 1) the operator is deterministic @@ -774,36 +775,34 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // LeftSemi/LeftAnti over Project case join @ Join(project @ Project(projectList, grandChild), rightOp, LeftSemiOrAnti(joinType), joinCond) - if !grandChild.isInstanceOf[LeafNode] && + if projectList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(projectList) && - 
projectList.forall(_.deterministic) => - - // Create a map of Aliases to their values from the child projection. - // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). - var projectListAfterUnalias = AttributeSet.empty - val aliasMap = AttributeMap(projectList.collect { - case a: Alias => - projectListAfterUnalias ++= a.child.references - (a.toAttribute, a.child) - }) + canPushThroughCondition(grandChild, joinCond) => - // If nothing to map from Join to the Project below - // stop the push down + // If this is over a simple Project, stop the push down val simple = grandChild match { + case _: LeafNode => true case Filter(_, l: LeafNode) => true case _ => false } - if (joinCond.isDefined && - // detect potential self-join after pushdown - joinCond.get.references.intersect(projectListAfterUnalias).isEmpty && - (aliasMap.nonEmpty || !simple)) { - val cond = if (joinCond.isDefined) { - Option(replaceAlias(joinCond.get, aliasMap)) - } else None - val res = Project(projectList, Join(grandChild, rightOp, joinType, cond)) - res - } else { + if (simple) { + // No push down join + } else if (joinCond.isEmpty) { + // No join condition, just push down the Join below Project + Project(projectList, Join(grandChild, rightOp, joinType, joinCond)) + } else { + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + val aliasMap = AttributeMap(projectList.collect { + case a: Alias => (a.toAttribute, a.child) + }) + val newJoinCond = if (aliasMap.nonEmpty) { + Option(replaceAlias(joinCond.get, aliasMap)) + } else { + joinCond + } + Project(projectList, Join(grandChild, rightOp, joinType, newJoinCond)) } // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be @@ -838,29 +837,33 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // LeftSemi/LeftAnti over Window case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond) if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet - + if (joinCond.isEmpty) { + // No join condition, just push down Join below Window + w.copy(child = Join(w.child, rightOp, joinType, joinCond)) + } else { + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ + rightOp.outputSet - val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { + val (candidates, containingNonDeterministic) = splitConjunctivePredicates(joinCond.get).span(_.deterministic) - } else { - (Nil, Nil) - } - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) && - !SubqueryExpression.hasCorrelatedSubquery(cond) && - !SubExprUtils.containsOuter(cond) - } + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } - val stayUp = rest ++ containingNonDeterministic + val stayUp = rest ++ containingNonDeterministic - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - join + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) + if 
(stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + } else { + // The join condition is not a subset of the Window's PARTITION BY clause, + // no push down. + join + } } case filter @ Filter(condition, aggregate: Aggregate) => @@ -899,49 +902,46 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Similar to the above Filter over Aggregate // LeftSemi/LeftAnti over Aggregate case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond) => - // Find all the aliased expressions in the aggregate list that don't include any actual - // AggregateExpression, and create a map from the alias to the expression - // TODO: detect potential self-join after push down??? - var projectListAfterUnalias = AttributeSet.empty - val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { - case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => - projectListAfterUnalias ++= a.child.references - (a.toAttribute, a.child) - }) - - // For each join condition, expand the alias and check if the condition can be evaluated using - // attributes produced by the aggregate operator's child operator. - val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { - splitConjunctivePredicates(joinCond.get).span(_.deterministic) + if (joinCond.isEmpty) { + // No join condition, just push down Join below Aggregate + aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond)) } else { - (Nil, Nil) - } + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + lazy val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) - val (pushDown, rest) = candidates.partition { cond => - val replaced = replaceAlias(cond, aliasMap) - cond.references.nonEmpty && - replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) && - !SubqueryExpression.hasCorrelatedSubquery(cond) && - !SubExprUtils.containsOuter(cond) - } + // For each join condition, expand the alias and + // check if the condition can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).span(_.deterministic) - val stayUp = rest ++ containingNonDeterministic + val (pushDown, rest) = candidates.partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) && + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) + } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - if (pushDownPredicate.references.intersect(projectListAfterUnalias).isEmpty) { + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) val newAggregate = aggregate.copy(child = Join(aggregate.child, rightOp, joinType, Option(replaced))) // If there is no more filter to stay up, just return the Aggregate over Join. // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". 
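+          // For example (illustrative only, using the suite's t1/t2 tables), in
+          //   SELECT t1b FROM t1 GROUP BY t1b
+          //   HAVING t1b IN (SELECT t2b FROM t2)
+          // the rewritten LeftSemi condition references only the grouping
+          // column t1b, so the join can sink below the Aggregate.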
if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) - } - else { + } else { + // The join condition is not a subset of the Aggregate's GROUP BY columns, + // no push down. join } - } else { - join } case filter @ Filter(condition, union: Union) => @@ -979,37 +979,40 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Similar to the above Filter over Union // LeftSemi/LeftAnti over Union case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond) => - // Union could change the rows, so non-deterministic predicate can't be pushed down - val (candidates, containingNonDeterministic) = if (joinCond.isDefined) { - splitConjunctivePredicates(joinCond.get).span(_.deterministic) - } else { - (Nil, Nil) + if (joinCond.isEmpty) { + // Push down the Join below Union + val newGrandChildren = union.children.map { grandchild => + Join(grandchild, rightOp, joinType, joinCond) } - val (pushDown, rest) = candidates.partition { cond => - !SubqueryExpression.hasCorrelatedSubquery(cond) && - !SubExprUtils.containsOuter(cond) - } - val stayUp = rest ++ containingNonDeterministic + union.withNewChildren(newGrandChildren) + } else { + // Union could change the rows, so non-deterministic predicate can't be pushed down + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).span(_.deterministic) - if (pushDown.nonEmpty) { - val pushDownCond = pushDown.reduceLeft(And) - val output = union.output - val newGrandChildren = union.children.map { grandchild => - val newCond = pushDownCond transform { - case e if output.exists(_.semanticEquals(e)) => - grandchild.output(output.indexWhere(_.semanticEquals(e))) - } - assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) - Join(grandchild, rightOp, joinType, Option(newCond)) + val (pushDown, rest) = candidates.partition { cond => + !SubqueryExpression.hasCorrelatedSubquery(cond) && + !SubExprUtils.containsOuter(cond) } - val newUnion = union.withNewChildren(newGrandChildren) - if (stayUp.nonEmpty) { - Filter(stayUp.reduceLeft(And), newUnion) + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = union.output + val newGrandChildren = union.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) + Join(grandchild, rightOp, joinType, Option(newCond)) + } + val newUnion = union.withNewChildren(newGrandChildren) + if (stayUp.isEmpty) newUnion else Filter(stayUp.reduceLeft(And), newUnion) } else { - newUnion + // Nothing to push down + join } - } else { - join } case filter @ Filter(condition, u: UnaryNode) @@ -1029,6 +1032,19 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } + /** + * Check if we can safely push a filter through a projection, by making sure that predicate + * subqueries in the condition do not contain the same attributes as the plan they are moved + * into. This can happen when the plan and predicate subquery have the same source. 
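+   * For example (an illustrative sketch, not a case from the test suite), in
+   *   SELECT a FROM (SELECT a FROM t) x WHERE a IN (SELECT a FROM t)
+   * the Project's child and the predicate subquery scan the same table t, so
+   * pushing the rewritten join below the Project could produce a self-join
+   * whose two sides carry the same attribute ids.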
+ */ + private def canPushThroughCondition(plan: LogicalPlan, condition: Option[Expression]): Boolean = { + val attributes = plan.outputSet + if (condition.isDefined) { + val matched = condition.get.references.intersect(attributes) + matched.isEmpty + } else true + } + private def canPushThrough(p: UnaryNode): Boolean = p match { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). From 2479bcdd478f4dcb2bf261e5665c49ca25af347f Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Sat, 25 Mar 2017 08:12:00 -0400 Subject: [PATCH 14/18] Fix bug in Join over Project that breaks LeftSemiOrAntiPushdownSuite TC 1.3 --- .../sql/catalyst/optimizer/Optimizer.scala | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6d5eb351f0a0..462da5d6e864 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -775,20 +775,11 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // LeftSemi/LeftAnti over Project case join @ Join(project @ Project(projectList, grandChild), rightOp, LeftSemiOrAnti(joinType), joinCond) - if projectList.forall(_.deterministic) && + if !tooSimplePlan(grandChild) && + projectList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(projectList) && - canPushThroughCondition(grandChild, joinCond) => - - // If this is over a simple Project, stop the push down - val simple = grandChild match { - case _: LeafNode => true - case Filter(_, l: LeafNode) => true - case _ => false - } - if (simple) { - // No push down - join - } else if (joinCond.isEmpty) { + canPushThroughCondition(grandChild, joinCond, rightOp) => + if (joinCond.isEmpty) { // No join condition, just push down the Join below Project Project(projectList, Join(grandChild, rightOp, joinType, joinCond)) } else { @@ -1032,15 +1023,26 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } + private def tooSimplePlan(plan: LogicalPlan) : Boolean = { + // If this is over a simple Project, stop the push down + plan match { + case _: LeafNode => true + case Filter(_, l: LeafNode) => true + case _ => false + } + } + /** - * Check if we can safely push a filter through a projection, by making sure that predicate - * subqueries in the condition do not contain the same attributes as the plan they are moved - * into. This can happen when the plan and predicate subquery have the same source. - */ - private def canPushThroughCondition(plan: LogicalPlan, condition: Option[Expression]): Boolean = { + * TODO: Update comment + * Check if we can safely push a join through a projection, by making sure that predicate + * subqueries in the condition do not contain the same attributes as the plan they are moved + * into. This can happen when the plan and predicate subquery have the same source. 
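+   * Compared to the earlier version, only the condition's references coming
+   * from the right-hand (subquery) side are intersected with the plan's
+   * output, so correlated references to the left side alone no longer block
+   * the pushdown (the bug behind TC 1.3 in LeftSemiOrAntiPushdownSuite).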
+ */ + private def canPushThroughCondition(plan: LogicalPlan, condition: Option[Expression], + rightOp: LogicalPlan): Boolean = { val attributes = plan.outputSet if (condition.isDefined) { - val matched = condition.get.references.intersect(attributes) + val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes) matched.isEmpty } else true } From bb8fad94340ee1409ad9be53572304af1c265510 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 3 Apr 2017 15:30:29 -0400 Subject: [PATCH 15/18] Update IN subquery pushdown test case --- .../spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 5 +++-- .../org/apache/spark/sql/catalyst/plans/PlanTest.scala | 7 ++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 4a55db9053c8..d6348a76bdba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -852,13 +852,14 @@ class FilterPushdownSuite extends PlanTest { ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) .analyze - val expectedPlan = x + val answer = x .join(z, Inner, Some("x.b".attr === "z.b".attr)) .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) .analyze val optimized = Optimize.execute(queryPlan) - comparePlans(optimized, expectedPlan) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("Window: predicate push down -- basic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index c73dfaf3f8fe..869d5a075427 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ /** * Provides helper methods for comparing plans. 
@@ -71,7 +72,11 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
         val newCondition =
           splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
             .reduce(And)
-        Join(left, right, joinType, Some(newCondition))
+        val maskedJoinType = if (joinType.isInstanceOf[ExistenceJoin]) {
+          val exists = AttributeReference("exists", BooleanType, false)(exprId = ExprId(0))
+          ExistenceJoin(exists)
+        } else joinType
+        Join(left, right, maskedJoinType, Some(newCondition))
     }
   }
 

From 4aaab02b6fa384c51aef8484255f7a51097842be Mon Sep 17 00:00:00 2001
From: Nattavut Sutyanyong
Date: Mon, 3 Apr 2017 16:24:04 -0400
Subject: [PATCH 16/18] Fix merge conflict

---
 .../spark/sql/catalyst/optimizer/joins.scala | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index dc12892c0364..e64bc7b147a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -478,18 +478,24 @@ case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with
     }
   }
 
-  private def buildNewJoinType(joinCond: Expression, join: Join, subqueryOutPut: AttributeSet):
+  private def buildNewJoinType(upperJoin: Join, lowerJoin: Join, otherTableOutput: AttributeSet):
     JoinType = {
-    val conditions = splitConjunctivePredicates(joinCond)
+    val conditions = upperJoin.constraints
+    // Find the predicates that reference only the other table.
+    val localConditions = conditions.filter(_.references.subsetOf(otherTableOutput))
+    // Find the predicates that reference either the left table or the join predicates
+    // between the left table and the other table.
     val leftConditions = conditions.filter(_.references.
-      subsetOf(join.left.outputSet ++ subqueryOutPut))
+      subsetOf(lowerJoin.left.outputSet ++ otherTableOutput)).diff(localConditions)
+    // Find the predicates that reference either the right table or the join predicates
+    // between the right table and the other table.
     val rightConditions = conditions.filter(_.references.
- subsetOf(join.right.outputSet ++ subqueryOutPut)) + subsetOf(lowerJoin.right.outputSet ++ otherTableOutput)).diff(localConditions) val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) - join.joinType match { + lowerJoin.joinType match { case RightOuter if leftHasNonNullPredicate => Inner case LeftOuter if rightHasNonNullPredicate => Inner case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner @@ -505,9 +511,9 @@ case class EliminateOuterJoin(conf: CatalystConf) extends Rule[LogicalPlan] with if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) case j @ Join(child @ Join(_, _, RightOuter | LeftOuter | FullOuter, _), subquery, LeftSemiOrAnti(joinType), joinCond) => - if (joinCond.isDefined) { - val newJoinType = buildNewJoinType(joinCond.get, child, subquery.outputSet) + val newJoinType = buildNewJoinType(j, child, subquery.outputSet) + if (newJoinType == child.joinType) j else { Join(child.copy(joinType = newJoinType), subquery, joinType, joinCond) - } else j + } } } From 0bab4fd335279accca5e90ed4ecdb1d7ea99383e Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Mon, 3 Apr 2017 20:26:47 -0400 Subject: [PATCH 17/18] Fix test failure HiveCompatibilitySuite/subquery_in_having --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 462da5d6e864..bf01442ad0ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -899,7 +899,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } else { // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression - lazy val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => (a.toAttribute, a.child) }) @@ -920,7 +920,10 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val stayUp = rest ++ containingNonDeterministic - if (pushDown.nonEmpty) { + // Make sure that the remaining predicate does not contain subquery's columns + val nonPushDown = rest.flatMap(_.references).intersect(rightOp.output) + + if (pushDown.nonEmpty && nonPushDown.isEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) val newAggregate = aggregate.copy(child = From 2081fac401cd744b9f4ee78e3c862fb6589a9b16 Mon Sep 17 00:00:00 2001 From: Nattavut Sutyanyong Date: Tue, 4 Apr 2017 15:54:04 -0400 Subject: [PATCH 18/18] Handle Aggregate/Window/Union under LeftSemi/Anti and new test cases --- .../sql/catalyst/optimizer/Optimizer.scala | 16 +- .../sql/LeftSemiOrAntiPushdownSuite.scala | 178 +++++++++++++++--- 2 files changed, 162 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bf01442ad0ff..26a309e04763 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -846,7 +846,10 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val stayUp = rest ++ containingNonDeterministic - if (pushDown.nonEmpty) { + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val pushDownPredicate = pushDown.reduce(And) val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) @@ -920,10 +923,10 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { val stayUp = rest ++ containingNonDeterministic - // Make sure that the remaining predicate does not contain subquery's columns - val nonPushDown = rest.flatMap(_.references).intersect(rightOp.output) + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - if (pushDown.nonEmpty && nonPushDown.isEmpty) { + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) val newAggregate = aggregate.copy(child = @@ -990,7 +993,10 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } val stayUp = rest ++ containingNonDeterministic - if (pushDown.nonEmpty) { + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val pushDownCond = pushDown.reduceLeft(And) val output = union.output val newGrandChildren = union.children.map { grandchild => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala index a1c35ca2a078..fb574ccb250e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.test.SharedSQLContext */ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { import testImplicits._ + import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Join} + import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti // setupTestData() @@ -82,6 +84,19 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { t5.createOrReplaceTempView("t5") } + private def checkLeftSemiOrAntiPlan(plan: LogicalPlan): Unit = { + plan match { + case j @ Join(_, _, LeftSemiOrAnti(_), _) => + // This is the expected result. + case _ => + fail( + s""" + |== FAIL: Top operator must be a LeftSemi or LeftAnti === + |${plan.toString} + """.stripMargin) + } + } + /** * TC 1.1: 1A-2B-3A-4B-5A-6A * Expected result: LeftAnti below Project @@ -336,8 +351,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) - """. 
- stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -365,8 +379,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -394,8 +407,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -424,8 +436,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where t2a not in (select t3a from t3 where t3b >= 1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -454,8 +465,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where not exists (select 1 from t3 where t3a = t1a and t3b > t2b) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -471,7 +481,8 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | on t3a = t1a and t3b > t2b """.stripMargin) checkAnswer(plan1, plan2) - // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) plan1.show } /** @@ -487,8 +498,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -516,8 +526,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where exists (select 1 from t3 where t3a = t2a and t3b >= 1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -538,6 +547,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { * Expected result: No push down */ test("TC 2.9: LeftSemi over left semi join with correlated columns on the left table") { + import org.apache.spark.sql.catalyst.plans.logical.Union val plan1 = sql( """ @@ -545,9 +555,12 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | (select * from t1 left semi join t2 on t1b = t2b and t2c >= 0) | select * | from join - | where exists (select 1 from t3 where t3a = t1a and t3c is not null) - """. - stripMargin) + | where exists (select 1 + | from (select * from t3 + | union all + | select * from t4) t3 + | where t3a = t1a and t3c is not null) + """.stripMargin) val plan2 = sql( """ @@ -558,11 +571,24 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | on t1b = t2b and t2c >= 0) | select * | from join - | left semi join t3 + | left semi join + | (select * from t3 + | union all + | select * from t4) t3 | on t3a = t1a and t3c is not null """.stripMargin) checkAnswer(plan1, plan2) - // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + val optPlan = plan1.queryExecution.optimizedPlan + optPlan match { + case j @ Join(_, _: Union, LeftSemiOrAnti(_), _) => + // This is the expected result. 
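+          // i.e. the outer LeftSemi stays on top with the whole Union as its
+          // right operand, rather than being replicated into the Union's
+          // children.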
+ case _ => + fail( + s""" + |== FAIL: The right operand of the top operator must be a Union === + |${optPlan.toString} + """.stripMargin) + } plan1.show } /** @@ -578,8 +604,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where exists (select 1 from t3 where t3a = t1a and t3a = t2a) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -594,7 +619,8 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | on t3a = t1a and t3a = t2a """.stripMargin) checkAnswer(plan1, plan2) - // comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) plan1.show } /** @@ -610,8 +636,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | select * | from join | where not exists (select 1 from t3 where t3b < -1) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -641,8 +666,7 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { | from join | where not exists (select 1 from t3 where t3b < -1) | and (t1c = 1 or t1c is null) - """. - stripMargin) + """.stripMargin) val plan2 = sql( """ @@ -658,4 +682,104 @@ class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) plan1.show } + /** + * TC 3.1: Negative case - LeftSemi over Aggregate + * Expected result: No push down + */ + test("TC 3.1: Negative case - LeftSemi over Aggregate") { + val plan1 = + sql( + """ + | select t1b, min(t1a) as min + | from t1 b + | group by t1b + | having t1b in (select t1b+1 + | from t1 a + | where a.t1a = min(b.t1a) ) + """.stripMargin) + val plan2 = + sql( + """ + | select b.* + | from (select t1b, min(t1a) as min + | from t1 + | group by t1b) b + | left semi join t1 + | on b.t1b = t1.t1b+1 + | and b.min = t1.t1a + | and t1.t1a is not null + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 3.2: Negative case - LeftAnti over Window + * Expected result: No push down + */ + test("TC 3.2: Negative case - LeftAnti over Window") { + val plan1 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | where not exists (select 1 + | from t1 a + | where a.t1a = b.min + | and a.t1b = b.t1b) + """.stripMargin) + val plan2 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | left anti join t1 a + | on a.t1a = b.min + | and a.t1b = b.t1b + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } + /** + * TC 3.3: Negative case - LeftSemi over Union + * Expected result: No push down + */ + test("TC 3.3: Negative case - LeftSemi over Union") { + val plan1 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | where exists (select 1 + | from t1 a + | where a.t1b = un.t2b + | and a.t1a = un.t2a + case when rand() < 0 then 1 else 0 end) + """.stripMargin) + val plan2 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | left semi join t1 a + | on a.t1b = un.t2b + | and a.t1a = un.t2a + """.stripMargin) + 
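+    // Note: plan2 drops the "+ case when rand() < 0 ..." term because rand()
+    // returns values in [0, 1), so that branch always contributes 0 and the
+    // two queries return the same rows; the optimizer, however, cannot push
+    // the nondeterministic predicate below the Union (hence no push down).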
checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + plan1.show + } }