diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala
index d530cfe5175b..1dc5d79dc1b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala
@@ -38,7 +38,7 @@ object NondeterministicExpressionCollection {
           case namedExpression: NamedExpression => namedExpression
           case _ => Alias(nondeterministicExpr, "_nondeterministic")()
         }
-        nonDeterministicToAttributes.put(nondeterministicExpr, namedExpression)
+        nonDeterministicToAttributes.put(nondeterministicExpr.canonicalized, namedExpression)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
index 6769babdd1f1..09d3a6f93a87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
@@ -42,7 +42,7 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
         NondeterministicExpressionCollection.getNondeterministicToAttributes(a.groupingExpressions)
       val newChild = Project(a.child.output ++ nondeterToAttr.values.asScala.toSeq, a.child)
       val deterministicAggregate = a.transformExpressions { case e =>
-        Option(nondeterToAttr.get(e)).map(_.toAttribute).getOrElse(e)
+        Option(nondeterToAttr.get(e.canonicalized)).map(_.toAttribute).getOrElse(e)
       }.copy(child = newChild)
       deterministicAggregate.groupingExpressions.foreach(expr =>
         if (!expr.deterministic) {
@@ -69,7 +69,7 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
       val nondeterToAttr =
         NondeterministicExpressionCollection.getNondeterministicToAttributes(p.expressions)
       val newPlan = p.transformExpressions { case e =>
-        Option(nondeterToAttr.get(e)).map(_.toAttribute).getOrElse(e)
+        Option(nondeterToAttr.get(e.canonicalized)).map(_.toAttribute).getOrElse(e)
       }
       val newChild = Project(p.child.output ++ nondeterToAttr.values.asScala.toSeq, p.child)
       Project(p.output, newPlan.withNewChildren(newChild :: Nil))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 6d5456462d8d..694f182f1296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -474,7 +474,8 @@
    * casted_col.cast(df.schema["col"].dataType)
    * }}}
    */
-  case class TestPythonUDF(name: String, returnType: Option[DataType] = None) extends TestUDF {
+  case class TestPythonUDF(name: String, returnType: Option[DataType] = None,
+      deterministic: Boolean = true) extends TestUDF {
     private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
       name = name,
       func = SimplePythonFunction(
@@ -487,7 +488,7 @@
         accumulator = null),
       dataType = StringType,
       pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
-      udfDeterministic = true) {
+      udfDeterministic = deterministic) {
     override def builder(e: Seq[Expression]): Expression = {
       assert(e.length == 1,
         "Defined UDF only has one column")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
index 10603cc3aeaf..9b40226c2049 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
-import org.apache.spark.sql.functions.{array, col, count, transform}
+import org.apache.spark.sql.functions.{array, avg, col, count, transform}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.LongType
 
@@ -139,4 +139,21 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
       checkAnswer(df, Row(0, 1, 1, 0, 1, 1))
     }
   }
+
+  test("SPARK-53311: Nondeterministic Python UDF pull out in aggregate with grouping") {
+    assume(shouldTestPythonUDFs)
+
+    // nondeterministic UDF
+    val pythonUDF = TestPythonUDF(name = "foo", Some(LongType), deterministic = false)
+
+    // This query should work without throwing an analysis exception.
+    // The UDF foo(value) appears in both grouping expressions and aggregate expressions;
+    // the fix ensures that both occurrences are mapped to the same attribute.
+    val df = spark.range(1)
+      .selectExpr("id", "id % 3 as value")
+      .groupBy(pythonUDF(col("value")))
+      .agg(avg("id"), pythonUDF(col("value")))
+
+    checkAnswer(df, Row(0, 0.0, 0))
+  }
 }