diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1db44496e67c..d37cbb2755b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -989,7 +989,7 @@ class Analyzer( withPosition(u) { try { outer.resolve(nameParts, resolver) match { - case Some(outerAttr) => OuterReference(outerAttr) + case Some(outerAttr) => OuterReference(outerAttr)() case None => u } } catch { @@ -1008,7 +1008,9 @@ class Analyzer( * * This method returns the rewritten subquery and correlated predicates. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + private def pullOutCorrelatedProjectionAndPredicates(sub: LogicalPlan) + : (LogicalPlan, Seq[NamedExpression], Seq[Expression]) = { + val outerProjectionSet = scala.collection.mutable.Set.empty[NamedExpression] val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] /** Make sure a plans' subtree does not contain a tagged predicate. */ @@ -1077,7 +1079,7 @@ class Analyzer( // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { // WARNING: - // Only Filter can host correlated expressions at this time + // Only Filter and Project can host correlated expressions at this time // Anyone adding a new "case" below needs to add the call to // "failOnOuterReference" to disallow correlated expressions in it. 
case f @ Filter(cond, child) => @@ -1102,12 +1104,19 @@ class Analyzer( child } case p @ Project(expressions, child) => - failOnOuterReference(p) + outerProjectionSet ++= expressions.filter(containsOuter) + val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) + val newProjectList = if (referencesToAdd.nonEmpty) { + expressions ++ referencesToAdd } else { - p + expressions + }.filterNot(x => outerProjectionSet.contains(x)) + + if (newProjectList.isEmpty) { + p.copy(projectList = Seq(Alias(Literal(1), "1")())) + } else { + p.copy(projectList = newProjectList) } case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) @@ -1162,7 +1171,7 @@ class Analyzer( failOnOuterReference(p) p } - (transformed, predicateMap.values.flatten.toSeq) + (transformed, outerProjectionSet.toSeq, predicateMap.values.flatten.toSeq) } /** @@ -1171,9 +1180,9 @@ class Analyzer( */ private def rewriteSubQuery( sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[NamedExpression], Seq[Expression]) = { // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) + val (basePlan, baseOutputs, baseConditions) = pullOutCorrelatedProjectionAndPredicates(sub) // Make sure the inner and the outer query attributes do not collide. 
val outputSet = outer.map(_.outputSet).reduce(_ ++ _) @@ -1199,7 +1208,11 @@ class Analyzer( val conditions = deDuplicatedConditions.map(_.transform { case OuterReference(ref) => ref }) - (plan, conditions) + val outputs = baseOutputs.map(_.transform { + case OuterReference(ref) => ref + }).asInstanceOf[Seq[NamedExpression]] + + (plan, outputs, conditions) } /** @@ -1213,7 +1226,8 @@ class Analyzer( e: SubqueryExpression, plans: Seq[LogicalPlan], requiredColumns: Int = 0)( - f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { + f: (LogicalPlan, Seq[NamedExpression], Seq[Expression]) => SubqueryExpression) + : SubqueryExpression = { // Step 1: Resolve the outer expressions. var previous: LogicalPlan = null var current = e.plan @@ -1252,18 +1266,26 @@ class Analyzer( private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans, 1) { (plan, _, children) => + ScalarSubquery(plan, children, exprId) + } case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) + resolveSubQuery(e, plans) { (plan, _, children) => + PredicateSubquery(plan, children, nullAware = false, exprId) + } case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => // Get the left hand side expressions. val expressions = e match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => + resolveSubQuery(l, plans, expressions.size) { (rewrite, exprs, conditions) => // Construct the IN conditions. 
- val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) + val inConditions = if (exprs.isEmpty) { + expressions.zip(rewrite.output).map(EqualTo.tupled) + } else { + expressions.zip(exprs).map(EqualTo.tupled) + } PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 127475713605..c562d10cabe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -356,10 +356,17 @@ case class PrettyAttribute( * A place holder used to hold a reference that has been resolved to a field outside of the current * plan. This is used for correlated subqueries. */ -case class OuterReference(e: NamedExpression) extends LeafExpression with Unevaluable { +case class OuterReference(e: NamedExpression)( + val exprId: ExprId = NamedExpression.newExprId) + extends LeafExpression with NamedExpression with Unevaluable { override def dataType: DataType = e.dataType override def nullable: Boolean = e.nullable override def prettyName: String = "outer" + + override def name: String = e.name + override def qualifier: Option[String] = e.qualifier + override def toAttribute: Attribute = e.toAttribute + override def newInstance(): NamedExpression = OuterReference(e)() } object VirtualColumn { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8c1faea2394c..a67cffe0c1bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -515,7 +515,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(OuterReference(a)(), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -524,7 +524,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(OuterReference(a)(), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -532,14 +532,14 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a)(), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(OuterReference(a)(), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil) @@ -547,7 +547,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(OuterReference(a)(), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/core/src/test/resources/sql-tests/inputs/subqueries.sql b/sql/core/src/test/resources/sql-tests/inputs/subqueries.sql new file mode 100644 index 000000000000..4041902dc6ff --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subqueries.sql @@ -0,0 +1,18 @@ +CREATE TEMPORARY 
VIEW t1 AS SELECT * FROM VALUES 1, 2 AS t1(a); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES 1 AS t2(b); + +-- IN with correlated predicate +SELECT a FROM t1 WHERE a IN (SELECT b FROM t2 WHERE a=b); + +-- NOT IN with correlated predicate +SELECT a FROM t1 WHERE a NOT IN (SELECT b FROM t2 WHERE a=b); + +-- IN with correlated projection +SELECT a FROM t1 WHERE a IN (SELECT a FROM t2); + +-- NOT IN with correlated projection +SELECT a FROM t1 WHERE a NOT IN (SELECT a FROM t2); + +-- IN with expressions +SELECT a FROM t1 WHERE a*1 IN (SELECT a%2 FROM t2); diff --git a/sql/core/src/test/resources/sql-tests/results/subqueries.sql.out b/sql/core/src/test/resources/sql-tests/results/subqueries.sql.out new file mode 100644 index 000000000000..a3486e9cc1c0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subqueries.sql.out @@ -0,0 +1,59 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES 1, 2 AS t1(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES 1 AS t2(b) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT a FROM t1 WHERE a IN (SELECT b FROM t2 WHERE a=b) +-- !query 2 schema +struct +-- !query 2 output +1 + + +-- !query 3 +SELECT a FROM t1 WHERE a NOT IN (SELECT b FROM t2 WHERE a=b) +-- !query 3 schema +struct +-- !query 3 output +2 + + +-- !query 4 +SELECT a FROM t1 WHERE a IN (SELECT a FROM t2) +-- !query 4 schema +struct +-- !query 4 output +1 +2 + + +-- !query 5 +SELECT a FROM t1 WHERE a NOT IN (SELECT a FROM t2) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +SELECT a FROM t1 WHERE a*1 IN (SELECT a%2 FROM t2) +-- !query 6 schema +struct +-- !query 6 output +1