diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 801bd2693af4..5aef82b64ed3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       (distinctAggOperatorMap.flatMap(_._2) ++
         regularAggOperatorMap.map(e => (e._1, e._3))).toMap
 
+    val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable)
     val patchedAggExpressions = a.aggregateExpressions.map { e =>
       e.transformDown {
         case e: Expression =>
           // The same GROUP BY clauses can have different forms (different names for instance) in
           // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
           // tricky. So we do a linear search for a semantically equal group by expression.
-          groupByMap
+          groupByMapNonFoldable
             .find(ge => e.semanticEquals(ge._1))
             .map(_._2)
             .getOrElse(transformations.getOrElse(e, e))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index ac136dfb898e..4d31999ded65 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
 import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
@@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       case _ => fail(s"Plan is not rewritten:\n$rewrite")
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
+    val relation = testRelation2
+      .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d")
+    val input = relation
+      .groupBy($"a", $"gb")(
+        countDistinct($"b").as("agg1"),
+        countDistinct($"d").as("agg2"),
+        Round(sum($"c").as("sum1"), 6)).analyze
+    val rewriteFold = FoldablePropagation(input)
+    // without the fix, the below produces an unresolved plan
+    val rewrite = RewriteDistinctAggregates(rewriteFold)
+    if (!rewrite.resolved) {
+      fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e9d34c3bd96..e80c3b23a7db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2490,6 +2490,27 @@ class DataFrameAggregateSuite extends QueryTest
       })
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in unresolved aggregation") {
+    val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c")
+    withTempView("v1") {
+      data.createOrReplaceTempView("v1")
+      val df =
+        sql("""SELECT
+              |  ROUND(SUM(b), 6) AS sum1,
+              |  COUNT(DISTINCT a) AS count1,
+              |  COUNT(DISTINCT c) AS count2
+              |FROM (
+              |  SELECT
+              |    6 AS gb,
+              |    *
+              |  FROM v1
+              |)
+              |GROUP BY a, gb
+              |""".stripMargin)
+      checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
+    }
+  }
 }
 
 case class B(c: Option[Double])
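Why filtering out foldable grouping keys matters, as I read the patch: after FoldablePropagation the grouping expression `gb` reduces to the literal `6`, which is semantically equal to the unrelated literal `6` used as the ROUND scale argument. The pre-fix lookup against `groupByMap` would then rewrite that scale argument into the grouping attribute emitted by Expand, and because ROUND requires a foldable scale the aggregate no longer resolved. A minimal spark-shell style sketch of the scenario the new end-to-end test covers (assumes a spark-shell session where `spark` and its implicits are in scope; data and query are taken from the test above):

  // Two different DISTINCT aggregates force RewriteDistinctAggregates to apply,
  // and the constant column `gb` in the GROUP BY triggers the old bug.
  val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", "b", "c")
  data.createOrReplaceTempView("v1")
  spark.sql(
    """SELECT ROUND(SUM(b), 6) AS sum1,
      |       COUNT(DISTINCT a) AS count1,
      |       COUNT(DISTINCT c) AS count2
      |FROM (SELECT 6 AS gb, * FROM v1)
      |GROUP BY a, gb""".stripMargin).show()
  // With the fix this returns rows (1.001, 1, 1) and (6.002, 1, 1);
  // before the fix the plan fails to resolve.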