diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e30aa0a79692..76e03d3d2ded 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1283,6 +1283,14 @@ def test_collect_functions(self):
             sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
             ["1", "2", "2", "2"])
 
+    def test_udf_with_aggregate_function(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 6e76e9569feb..0601cbd40733 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -64,11 +64,18 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
             assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
 
             // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
+            val transformed = plan.transformExpressions {
+              case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
+            }.withNewChildren(newChildren)
+
+            // If plan is an [[Aggregate]], return evaluated plan as is for
+            // [[ResolveAggregateFunctions]] rule in a batch for further resolution.
+            // Otherwise, construct a [[Project]] with evaluated udf.
+            if (plan.isInstanceOf[Aggregate]) {
+              transformed
+            } else {
+              logical.Project(plan.output, transformed)
+            }
 
           case None =>
             // If there is no Python UDF that is resolved, skip this round.
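
Note: the scenario exercised by the new `test_udf_with_aggregate_function` test can be reproduced outside the test harness with the sketch below. It mirrors the test body, but assumes a standalone `SparkContext`/`SQLContext` of the same era (the test itself uses `self.sqlCtx`); the app name and variable names are illustrative and not part of the patch.

```python
# Minimal sketch of the scenario covered by test_udf_with_aggregate_function,
# assuming a standalone 1.x/2.0-era SQLContext; names here are illustrative.
from pyspark import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.sql.functions import col, udf
from pyspark.sql.types import BooleanType

sc = SparkContext(appName="udf-after-distinct-sketch")
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame(
    [(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

# A Python UDF used only as a filter predicate on top of distinct(); this is
# the aggregate case the new isInstanceOf[Aggregate] branch in
# ExtractPythonUDFs is meant to handle.
my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))

# Expected result per the new test: only the distinct key equal to 1 survives.
assert sel.collect() == [Row(key=1)]

sc.stop()
```

Per the comment added in ExtractPythonUDFs.scala, when the rewritten plan is an Aggregate it is now returned as is so that ResolveAggregateFunctions can finish resolution, instead of being wrapped in an extra Project as before.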