From 42ceb8003a2c312e8e80f01960cc983ad40c7fae Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Thu, 23 Jan 2025 11:03:06 +0800 Subject: [PATCH] [SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle the case where an `Aggregate` node contains an aggregate expression in the left-hand operand of an IN-subquery expression. The explicit handler moves the IN-subquery expressions out of the `Aggregate` and into a parent `Project` node. The `Aggregate` will continue to perform the aggregations that were used as an operand to the IN-subquery expression, but will not include the IN-subquery expression itself. After pulling up IN-subquery expressions into a Project node, `RewritePredicateSubquery#apply` is called again to handle the `Project` as a `UnaryNode`. The `Join` will now be inserted between the `Project` and the `Aggregate` node, and the join condition will use an attribute rather than an aggregate expression, e.g.: ``` Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40] +- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L) :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L] : +- LocalRelation [col1#32, col2#33] +- LocalRelation [c2#39L] ``` `sum(col2)#41L` in the above join condition, despite how it looks, is the name of the attribute, not an aggregate expression. The following query fails: ``` create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1); create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1); select col1, sum(col2) in (select c2 from v1) from v2 group by col1; ``` It fails with this error: ``` [INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, false]) SQLSTATE: XX000 ``` With SPARK_TESTING=1, it fails with this error: ``` [PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch RewriteSubquery generated an invalid plan: Special expressions are placed in the wrong plan: Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19] +- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L) :- LocalRelation [col1#11, col2#12] +- LocalRelation [c2#18L] ``` The issue is that `RewritePredicateSubquery` builds a `Join` operator where the join condition contains an aggregate expression. The bug is in the handler for `UnaryNode` in `RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and assumes that the left-hand operand of IN-subquery can be used in the join condition. This works fine for most cases, but not when the left-hand operand is an aggregate expression. This PR moves the offending IN-subqueries to a `Project` node, with the aggregates replaced by attributes referring to the aggregate expressions. The resulting join condition now uses those attributes rather than the actual aggregate expressions. No, other than allowing this type of query to succeed. New unit tests. No. Closes #48627 from bersprockets/aggregate_in_set_issue. Authored-by: Bruce Robbins Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/subquery.scala | 96 +++++++++++++++++-- .../optimizer/RewriteSubquerySuite.scala | 19 +++- .../org/apache/spark/sql/SubquerySuite.scala | 30 ++++++ 3 files changed, 136 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index ee2005315781..0652ee221c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -100,6 +101,25 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } + def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = { + exprs.exists { expr => + exprContainsAggregateInSubquery(expr) + } + } + + def exprContainsAggregateInSubquery(expr: Expression): Boolean = { + expr.exists { + case InSubquery(values, _) => + values.exists { v => + v.exists { + case _: AggregateExpression => true + case _ => false + } + } + case _ => false; + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) { case Filter(condition, child) @@ -162,15 +182,75 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { Project(p.output, Filter(newCond.get, inputPlan)) } + // Handle the case where the left-hand side of an IN-subquery contains an aggregate. + // + // If an Aggregate node contains such an IN-subquery, this handler will pull up all + // expressions from the Aggregate node into a new Project node. The new Project node + // will then be handled by the Unary node handler. + // + // The Unary node handler uses the left-hand side of the IN-subquery in a + // join condition. Thus, without this pre-transformation, the join condition + // contains an aggregate, which is illegal. With this pre-transformation, the + // join condition contains an attribute from the left-hand side of the + // IN-subquery contained in the Project node. + // + // For example: + // + // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x + // FROM v2; + // + // The above query has this plan on entry to RewritePredicateSubquery#apply: + // + // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- LocalRelation [col2#18, col3#19] + // + // Note that the Aggregate node contains the IN-subquery and the left-hand + // side of the IN-subquery is an aggregate expression sum(col2#18)). + // + // This handler transforms the above plan into the following: + // scalastyle:off line.size.limit + // + // Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13] + // : +- LocalRelation [c3#28L] + // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L] + // +- LocalRelation [col2#18, col3#19] + // + // scalastyle:on + // Note that both the IN-subquery and the greater-than expressions have been + // pulled up into the Project node. These expressions use attributes + // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations + // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)). + case p @ PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) + if exprsContainsAggregateInSubquery(p.expressions) => + val aggExprs = aggregateExpressions.map( + ae => Alias(ae, "_aggregateexpression")(ae.resultId)) + val aggExprIds = aggExprs.map(_.exprId).toSet + val resExprs = resultExpressions.map(_.transform { + case a: AttributeReference if aggExprIds.contains(a.exprId) => + a.withName("_aggregateexpression") + }.asInstanceOf[NamedExpression]) + // Rewrite the projection and the aggregate separately and then piece them together. + val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) + val newProj = Project(resExprs, newAgg) + handleUnaryNode(newProj) + case u: UnaryNode if u.expressions.exists( - SubqueryExpression.hasInOrCorrelatedExistsSubquery) => - var newChild = u.child - u.mapExpressions(expr => { - val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild) - newChild = p - // The newExpr can not be None - newExpr.get - }).withNewChildren(Seq(newChild)) + SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u) + } + + /** + * Handle the unary node case + */ + private def handleUnaryNode(u: UnaryNode): LogicalPlan = { + var newChild = u.child + u.mapExpressions(expr => { + val (newExpr, p) = rewriteExistentialExpr(Seq(expr), newChild) + newChild = p + // The newExpr can not be None + newExpr.get + }).withNewChildren(Seq(newChild)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 17547bbcb940..c45a761353c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not} +import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.LongType class RewriteSubquerySuite extends PlanTest { @@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest { Optimize.executeAndTrack(query.analyze, tracker) assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0) } + + test("SPARK-50091: Don't put aggregate expression in join condition") { + val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int) + val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int) + val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3")))) + val optimized = Optimize.execute(plan.analyze) + val aggregate = relation2 + .select($"col2") + .groupBy()(sum($"col2").as("_aggregateexpression")) + val correctAnswer = aggregate + .join(relation1.select(Cast($"c3", LongType).as("c3")), + ExistenceJoin($"exists".boolean.withNullability(false)), + Some($"_aggregateexpression" === $"c3")) + .select($"exists".as("(sum(col2) IN (listquery()))")).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 260c992f1aed..04702201f82f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest checkAnswer(df3, Row(7)) } } + + test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") { + withView("v1", "v2") { + Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8)) + .toDF("c1", "c2", "c3") + .createOrReplaceTempView("v1") + Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1)) + .toDF("col1", "col2", "col3") + .createOrReplaceTempView("v2") + + val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1") + checkAnswer(df1, + Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil) + + val df2 = sql("""SELECT + | col1, + | SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x + |FROM v2 GROUP BY col1 + |ORDER BY col1""".stripMargin) + checkAnswer(df2, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) + + val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x + |FROM v2 + |GROUP BY col1 + |ORDER BY col1""".stripMargin) + checkAnswer(df3, + Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil) + } + } }