From 82b5bacce80064df1feb087293f7d13ad72334ca Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 25 Aug 2017 05:40:17 +0000 Subject: [PATCH 1/3] Deduplicate join output for correlated predicate subquery. --- .../sql/catalyst/optimizer/subquery.scala | 27 +++++++- .../org/apache/spark/sql/SubquerySuite.scala | 68 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 4386a1016276..62343da73cc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -49,6 +49,30 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } + def dedupJoin(plan: LogicalPlan): LogicalPlan = { + plan transform { + case j @ Join(left, right, joinType, joinCond) => + val duplicates = right.outputSet.intersect(left.outputSet) + if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = right.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val newRight = Project(aliasedExpressions, right) + val newJoinCond = joinCond.map { condExpr => + condExpr transform { + case a: Attribute => aliasMap.getOrElse(a, a).toAttribute + } + } + Join(left, newRight, joinType, newJoinCond) + } else { + j + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = @@ -61,7 +85,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } // Filter the plan by applying left semi and left anti joins. 
- withSubquery.foldLeft(newFilter) { + val rewritten = withSubquery.foldLeft(newFilter) { case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) @@ -98,6 +122,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) } + dedupJoin(rewritten) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 274694b99541..a02d99182816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { @@ -875,4 +876,71 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]")) } } + + test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") { + withTable("t1") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'") + + val sqlText = + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin + val ds = sql(sqlText) + val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan + val join = optimizedPlan.collect { + case j: Join => j + }.head.asInstanceOf[Join] + assert(join.duplicateResolved) + assert(optimizedPlan.resolved) + } + } + } + + test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") { + withTable("t1", "t2", "t3") { + withTempPath { path => + val data = Seq((1, 1, 1), (2, 0, 2)) + + data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1") + data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2") + data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3") + + sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'") + sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'") + sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'") + + val sqlText = + s""" + |SELECT * + |FROM (SELECT * + | FROM t2 + | WHERE t2c IN (SELECT t1c + | FROM t1 + | WHERE t1a = t2a) + | UNION + | SELECT * + | FROM t3 + | WHERE t3a IN (SELECT t2a + | FROM t2 + | UNION ALL + | SELECT t1a + | FROM t1 + | WHERE t1b > 0)) t4 + |WHERE t4.t2b IN (SELECT Min(t3b) + | FROM t3 + | WHERE t4.t2a = t3a) + """.stripMargin + val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan + val joinNodes = optimizedPlan.collect { + case j: Join => j + }.map(_.asInstanceOf[Join]) + joinNodes.map(j => assert(j.duplicateResolved)) + assert(optimizedPlan.resolved) + } + } + } } From c1325fb9b1f8501b1a31b61e9b39bf1213b021f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Sep 2017 03:01:31 +0000 Subject: [PATCH 2/3] Address comments. 
---
 .../sql/catalyst/optimizer/subquery.scala | 58 ++++++++++---------
 .../org/apache/spark/sql/SubquerySuite.scala |  1 -
 2 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 62343da73cc2..7ff891516dac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -49,28 +49,31 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
-  def dedupJoin(plan: LogicalPlan): LogicalPlan = {
-    plan transform {
-      case j @ Join(left, right, joinType, joinCond) =>
-        val duplicates = right.outputSet.intersect(left.outputSet)
-        if (duplicates.nonEmpty) {
-          val aliasMap = AttributeMap(duplicates.map { dup =>
-            dup -> Alias(dup, dup.toString)()
-          }.toSeq)
-          val aliasedExpressions = right.output.map { ref =>
-            aliasMap.getOrElse(ref, ref)
-          }
-          val newRight = Project(aliasedExpressions, right)
-          val newJoinCond = joinCond.map { condExpr =>
-            condExpr transform {
-              case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
-            }
+  private def dedupJoin(joinPlan: Join): Join = joinPlan match {
+    // SPARK-21835: It is possible that the two sides of the join have conflicting attributes;
+    // the produced join then becomes unresolved and breaks structural integrity. We should
+    // de-duplicate the conflicting attributes. We don't use a transformation here because we
+    // only care about the topmost join converted from a correlated predicate subquery.
+    case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti), joinCond) =>
+      val duplicates = right.outputSet.intersect(left.outputSet)
+      if (duplicates.nonEmpty) {
+        val aliasMap = AttributeMap(duplicates.map { dup =>
+          dup -> Alias(dup, dup.toString)()
+        }.toSeq)
+        val aliasedExpressions = right.output.map { ref =>
+          aliasMap.getOrElse(ref, ref)
+        }
+        val newRight = Project(aliasedExpressions, right)
+        val newJoinCond = joinCond.map { condExpr =>
+          condExpr transform {
+            case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
           }
-          Join(left, newRight, joinType, newJoinCond)
-        } else {
-          j
         }
-    }
+        Join(left, newRight, joinType, newJoinCond)
+      } else {
+        j
+      }
+    case _ => joinPlan
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,17 +88,20 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
 
       // Filter the plan by applying left semi and left anti joins.
-      val rewritten = withSubquery.foldLeft(newFilter) {
+      withSubquery.foldLeft(newFilter) {
         case (p, Exists(sub, conditions, _)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          Join(outerPlan, sub, LeftSemi, joinCond)
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
         case (p, Not(Exists(sub, conditions, _))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          Join(outerPlan, sub, LeftAnti, joinCond)
+          // Deduplicate conflicting attributes if any.
+          dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
         case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
           val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
-          Join(outerPlan, sub, LeftSemi, joinCond)
+          // Deduplicate conflicting attributes if any.
+ dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond)) case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive @@ -117,12 +123,12 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // will have the final conditions in the LEFT ANTI as // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) - Join(outerPlan, sub, LeftAnti, Option(pairs)) + // Deduplicate conflicting attributes if any. + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) } - dedupJoin(rewritten) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index a02d99182816..0d56d777e8f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -889,7 +889,6 @@ class SubquerySuite extends QueryTest with SharedSQLContext { |WHERE |NOT EXISTS (SELECT * FROM t1) """.stripMargin - val ds = sql(sqlText) val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan val join = optimizedPlan.collect { case j: Join => j From 85508287ca1b98f3a3c341efd3ac70f99b56bc73 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Sep 2017 04:12:27 +0000 Subject: [PATCH 3/3] Address comment. --- .../scala/org/apache/spark/sql/SubquerySuite.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0d56d777e8f6..ee6905e999df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -890,9 +890,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { |NOT EXISTS (SELECT * FROM t1) """.stripMargin val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan - val join = optimizedPlan.collect { - case j: Join => j - }.head.asInstanceOf[Join] + val join = optimizedPlan.collectFirst { case j: Join => j }.get assert(join.duplicateResolved) assert(optimizedPlan.resolved) } @@ -934,10 +932,8 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | WHERE t4.t2a = t3a) """.stripMargin val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan - val joinNodes = optimizedPlan.collect { - case j: Join => j - }.map(_.asInstanceOf[Join]) - joinNodes.map(j => assert(j.duplicateResolved)) + val joinNodes = optimizedPlan.collect { case j: Join => j } + joinNodes.foreach(j => assert(j.duplicateResolved)) assert(optimizedPlan.resolved) } }
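
A quick way to exercise the rewrite above from spark-shell, following the "case 1"
test in this series. This is a sketch, not part of the patch: it assumes a build
that includes these changes, and it uses a temp view instead of a parquet-backed
table (the view name and data are illustrative only). In spark-shell, `spark` and
`spark.implicits._` are already in scope.

  import org.apache.spark.sql.catalyst.plans.logical.Join

  // NOT EXISTS over the same relation: both sides of the rewritten LeftAnti
  // join expose identical attributes, which is what dedupJoin resolves.
  Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("t1")

  val plan = spark.sql("SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t1)")
    .queryExecution.optimizedPlan

  // With the fix, the join's left/right outputs no longer overlap and the
  // optimized plan stays resolved.
  plan.collectFirst { case j: Join => j }.foreach(j => assert(j.duplicateResolved))
  assert(plan.resolved)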