From abc9a0d2af35b02940e78c693999f3ea8a652d8f Mon Sep 17 00:00:00 2001
From: Mihailo Timotic
Date: Tue, 20 May 2025 17:36:51 +0200
Subject: [PATCH] fix

---
 .../connect/planner/SparkConnectPlanner.scala | 26 ++++++++-------
 .../planner/SparkConnectPlannerSuite.scala    | 33 ++++++++++++++++++-
 2 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d821046d0831..1bdf408664e4 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2457,11 +2457,13 @@ class SparkConnectPlanner(
         input
       }
 
-    val groupingExpressionsWithOrdinals = rel.getGroupingExpressionsList.asScala.toSeq
-      .map(transformGroupingExpressionAndReplaceOrdinals)
+    val groupingExpressions =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
+    val groupingExpressionsWithOrdinals =
+      groupingExpressions.map(replaceOrdinalsInGroupingExpressions)
     val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
       .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
-    val aliasedAgg = (groupingExpressionsWithOrdinals ++ aggExprs).map(toNamedExpression)
+    val aliasedAgg = (groupingExpressions ++ aggExprs).map(toNamedExpression)
 
     rel.getGroupType match {
       case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
@@ -2506,7 +2508,10 @@
         val groupingSetsExpressionsWithOrdinals = rel.getGroupingSetsList.asScala.toSeq.map {
           getGroupingSets =>
             getGroupingSets.getGroupingSetList.asScala.toSeq
-              .map(transformGroupingExpressionAndReplaceOrdinals)
+              .map(groupingExpressions => {
+                val transformedGroupingExpression = transformExpression(groupingExpressions)
+                replaceOrdinalsInGroupingExpressions(transformedGroupingExpression)
+              })
         }
         logical.Aggregate(
           groupingExpressions = Seq(
@@ -2521,18 +2526,15 @@
   }
 
   /**
-   * Transforms an input protobuf grouping expression into the Catalyst expression and converts
-   * top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is enabled.
+   * Replaces top-level integer [[Literal]]s to [[UnresolvedOrdinal]]s, if `groupByOrdinal` is
+   * enabled.
    */
-  private def transformGroupingExpressionAndReplaceOrdinals(
-      groupingExpression: proto.Expression) = {
-    val transformedGroupingExpression = transformExpression(groupingExpression)
+  private def replaceOrdinalsInGroupingExpressions(groupingExpression: Expression) =
     if (session.sessionState.conf.groupByOrdinal) {
-      replaceIntegerLiteralWithOrdinal(transformedGroupingExpression)
+      replaceIntegerLiteralWithOrdinal(groupingExpression)
     } else {
-      transformedGroupingExpression
+      groupingExpression
     }
-  }
 
   @deprecated("TypedReduce is now implemented using a normal UDAF aggregator.", "4.0.0")
   private def transformTypedReduceExpression(
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 72f7065b4424..9a69151ccfc2 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedFunction, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connect.SparkConnectTestUtils
@@ -922,4 +923,34 @@
     assert(fn3.nameParts.head == "abcde")
     assert(fn3.isInternal)
   }
+
+  test("SPARK-51820 aggregate list should not contain UnresolvedOrdinal") {
+    val ordinal = proto.Expression
+      .newBuilder()
+      .setLiteral(proto.Expression.Literal.newBuilder().setInteger(1).build())
+      .build()
+
+    val sum =
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .setFunctionName("sum")
+            .addArguments(ordinal))
+        .build()
+
+    val aggregate = proto.Aggregate.newBuilder
+      .setInput(readRel)
+      .addAggregateExpressions(sum)
+      .addGroupingExpressions(ordinal)
+      .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+      .build()
+
+    val plan =
+      transform(proto.Relation.newBuilder.setAggregate(aggregate).build()).asInstanceOf[Aggregate]
+
+    assert(plan.aggregateExpressions.forall(aggregateExpression =>
+      !aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL)))
+  }
 }