diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py index ac9ae67ac446..61bf70c69f9f 100644 --- a/python/pyspark/sql/functions/__init__.py +++ b/python/pyspark/sql/functions/__init__.py @@ -534,7 +534,8 @@ "user", # "uuid": Excluded because of the name conflict with builtin uuid module "version", - # UDF, UDTF and UDT + # UDF, UDAF, UDTF and UDT + "Aggregator", "AnalyzeArgument", "AnalyzeResult", "ArrowUDFType", @@ -544,6 +545,7 @@ "SelectedColumn", "SkipRestOfInputTableException", "UserDefinedFunction", + "UserDefinedAggregateFunction", "UserDefinedTableFunction", "arrow_udf", # Geospatial ST Functions @@ -556,6 +558,7 @@ "call_udf", "pandas_udf", "udf", + "udaf", "udtf", "arrow_udtf", "unwrap_udt", diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 04b800be2372..96a6ac2f6929 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -54,6 +54,7 @@ # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401 +from pyspark.sql.udaf import Aggregator, UserDefinedAggregateFunction, udaf # noqa: F401 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401 from pyspark.sql.udtf import OrderingColumn, PartitioningColumn, SelectedColumn # noqa: F401 from pyspark.sql.udtf import SkipRestOfInputTableException # noqa: F401 diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 939f7ff6b610..bdcfc25f3fc0 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -187,6 +187,17 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame": # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" exprs = cast(Tuple[Column, ...], exprs) + + # Check if any column is a UDAF column (has _udaf_func attribute) + from pyspark.sql.udaf import _handle_udaf_aggregation_in_grouped_data + + udaf_cols = [c for c in exprs if hasattr(c, "_udaf_func")] + if udaf_cols: + return _handle_udaf_aggregation_in_grouped_data( + self._df, self._jgd, exprs, udaf_cols + ) + + # Normal column aggregation jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.session._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.session) diff --git a/python/pyspark/sql/tests/test_udaf.py b/python/pyspark/sql/tests/test_udaf.py new file mode 100644 index 000000000000..93ac69e199ca --- /dev/null +++ b/python/pyspark/sql/tests/test_udaf.py @@ -0,0 +1,612 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import datetime +import unittest +from collections import Counter +from decimal import Decimal as D + +from pyspark.sql import Row +from pyspark.sql.functions import udaf +from pyspark.sql.types import ( + ArrayType, + BinaryType, + DecimalType, + IntegerType, + LongType, + MapType, + StringType, + StructField, + StructType, +) +from pyspark.sql.udaf import Aggregator, UserDefinedAggregateFunction +from pyspark.errors import PySparkTypeError, PySparkNotImplementedError +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.utils import assertDataFrameEqual + + +class UDAFTestsMixin: + # ============ Core Functionality Tests ============ + + def test_udaf_basic_aggregations(self): + """Test basic aggregation types: sum, average, max.""" + + class SumAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (value or 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + class AvgAgg(Aggregator): + @staticmethod + def zero(): + return (0.0, 0) + + @staticmethod + def reduce(buffer, value): + if value is not None: + return (buffer[0] + value, buffer[1] + 1) + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1]) + + @staticmethod + def finish(reduction): + return reduction[0] / reduction[1] if reduction[1] else None + + class MaxAgg(Aggregator): + @staticmethod + def zero(): + return None + + @staticmethod + def reduce(buffer, value): + if value is None: + return buffer + return max(buffer, value) if buffer is not None else value + + @staticmethod + def merge(buffer1, buffer2): + if buffer1 is None: + return buffer2 + if buffer2 is None: + return buffer1 + return max(buffer1, buffer2) + + @staticmethod + def finish(reduction): + return reduction + + df = self.spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,), (5.0,)], ["value"]) + + # Sum + sum_udaf = udaf(SumAgg(), "double", "MySum") + result = df.agg(sum_udaf(df.value)) + assertDataFrameEqual(result, [Row(**{"MySum(value)": 15.0})]) + + # Average (uses tuple buffer) + avg_udaf = udaf(AvgAgg(), "double", "MyAvg") + result = df.agg(avg_udaf(df.value)) + assertDataFrameEqual(result, [Row(**{"MyAvg(value)": 3.0})]) + + # Max + max_udaf = udaf(MaxAgg(), "double", "MyMax") + result = df.agg(max_udaf(df.value)) + assertDataFrameEqual(result, [Row(**{"MyMax(value)": 5.0})]) + + def test_udaf_with_groupby(self): + """Test UDAF with groupBy - multiple groups, column order independent.""" + + class SumAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (value or 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + # Test with key column BEFORE value column (previously failed) + df = self.spark.createDataFrame( + [("a", 1), ("a", 2), ("a", 3), ("b", 10), ("b", 20), ("c", 100)], + ["key", "value"], + ) + sum_udaf = udaf(SumAgg(), "bigint", "MySum") + result = df.groupBy("key").agg(sum_udaf(df.value)) + expected = [ + Row(key="a", **{"MySum(value)": 6}), + Row(key="b", **{"MySum(value)": 30}), + Row(key="c", **{"MySum(value)": 100}), + ] + assertDataFrameEqual(result, expected, checkRowOrder=False) + + def test_udaf_edge_cases(self): + """Test edge cases: nulls, empty dataframe.""" + + class SumAgg(Aggregator): + @staticmethod + def zero(): + 
return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (value or 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + sum_udaf = udaf(SumAgg(), "bigint", "MySum") + + # Nulls + df = self.spark.createDataFrame([(1,), (None,), (3,), (None,), (5,)], ["value"]) + result = df.agg(sum_udaf(df.value)) + assertDataFrameEqual(result, [Row(**{"MySum(value)": 9})]) + + # Empty DataFrame + empty_df = self.spark.createDataFrame([], "value: int") + result = empty_df.agg(sum_udaf(empty_df.value)) + assertDataFrameEqual(result, [Row(**{"MySum(value)": 0})]) + + def test_udaf_large_dataset(self): + """Test UDAF with large dataset to verify multi-partition aggregation.""" + + class AvgAgg(Aggregator): + @staticmethod + def zero(): + return (0.0, 0) + + @staticmethod + def reduce(buffer, value): + if value is not None: + return (buffer[0] + value, buffer[1] + 1) + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1]) + + @staticmethod + def finish(reduction): + return reduction[0] / reduction[1] if reduction[1] else None + + data = [(float(i),) for i in range(1, 20001)] + df = self.spark.createDataFrame(data, ["value"]).repartition(8) + + avg_udaf = udaf(AvgAgg(), "double", "MyAvg") + result = df.agg(avg_udaf(df.value)) + expected_avg = 20001.0 / 2.0 + assertDataFrameEqual(result, [Row(**{"MyAvg(value)": expected_avg})]) + + # ============ Input Type Tests (Consolidated) ============ + + def test_udaf_various_input_types(self): + """Test UDAF with various Spark input types.""" + + class CountAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (1 if value is not None else 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + count_udaf = udaf(CountAgg(), "bigint", "Count") + + # String input + df_str = self.spark.createDataFrame([("a",), ("b",), ("c",)], ["value"]) + result = df_str.agg(count_udaf(df_str.value)) + assertDataFrameEqual(result, [Row(**{"Count(value)": 3})]) + + # Boolean input + df_bool = self.spark.createDataFrame([(True,), (False,), (True,), (None,)], ["flag"]) + result = df_bool.agg(count_udaf(df_bool.flag)) + assertDataFrameEqual(result, [Row(**{"Count(flag)": 3})]) + + # Date input + df_date = self.spark.createDataFrame( + [(datetime.date(2024, 1, 1),), (datetime.date(2024, 1, 10),), (None,)], + ["dt"], + ) + result = df_date.agg(count_udaf(df_date.dt)) + assertDataFrameEqual(result, [Row(**{"Count(dt)": 2})]) + + # Timestamp input + df_ts = self.spark.createDataFrame( + [(datetime.datetime(2024, 1, 1, 10, 0),), (datetime.datetime(2024, 1, 1, 11, 0),)], + ["ts"], + ) + result = df_ts.agg(count_udaf(df_ts.ts)) + assertDataFrameEqual(result, [Row(**{"Count(ts)": 2})]) + + # Decimal input + schema = StructType([StructField("amount", DecimalType(10, 2), True)]) + df_dec = self.spark.createDataFrame([(D("10.50"),), (D("20.25"),)], schema) + result = df_dec.agg(count_udaf(df_dec.amount)) + assertDataFrameEqual(result, [Row(**{"Count(amount)": 2})]) + + def test_udaf_complex_types(self): + """Test UDAF with complex types: array, map, struct, binary.""" + + # Array input - sum all elements + class ArraySumAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + sum(value) if value else buffer + + 
@staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + schema = StructType([StructField("nums", ArrayType(IntegerType()), True)]) + df_arr = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], schema) + arr_udaf = udaf(ArraySumAgg(), "bigint", "ArraySum") + result = df_arr.agg(arr_udaf(df_arr.nums)) + assertDataFrameEqual(result, [Row(**{"ArraySum(nums)": 15})]) + + # Map input - count entries + class MapCountAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + len(value) if value else buffer + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + schema = StructType([StructField("kv", MapType(StringType(), IntegerType()), True)]) + df_map = self.spark.createDataFrame([({"a": 1, "b": 2},), ({"c": 3},)], schema) + map_udaf = udaf(MapCountAgg(), "bigint", "MapCount") + result = df_map.agg(map_udaf(df_map.kv)) + assertDataFrameEqual(result, [Row(**{"MapCount(kv)": 3})]) + + # Struct input - average age + class StructAvgAgg(Aggregator): + @staticmethod + def zero(): + return (0, 0) + + @staticmethod + def reduce(buffer, value): + if value and value["age"]: + return (buffer[0] + value["age"], buffer[1] + 1) + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1]) + + @staticmethod + def finish(reduction): + return reduction[0] / reduction[1] if reduction[1] else None + + schema = StructType( + [ + StructField( + "person", + StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ), + True, + ) + ] + ) + df_struct = self.spark.createDataFrame( + [({"name": "Alice", "age": 30},), ({"name": "Bob", "age": 40},)], schema + ) + struct_udaf = udaf(StructAvgAgg(), "double", "AvgAge") + result = df_struct.agg(struct_udaf(df_struct.person)) + assertDataFrameEqual(result, [Row(**{"AvgAge(person)": 35.0})]) + + # Binary input - total length + class BinaryLenAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + len(value) if value else buffer + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + schema = StructType([StructField("data", BinaryType(), True)]) + df_bin = self.spark.createDataFrame([(bytearray(b"abc"),), (bytearray(b"de"),)], schema) + bin_udaf = udaf(BinaryLenAgg(), "bigint", "BinLen") + result = df_bin.agg(bin_udaf(df_bin.data)) + assertDataFrameEqual(result, [Row(**{"BinLen(data)": 5})]) + + # ============ Complex Aggregation Tests ============ + + def test_udaf_statistical_and_special(self): + """Test statistical aggregations and special return types.""" + + # Standard deviation + class StdDevAgg(Aggregator): + @staticmethod + def zero(): + return (0, 0.0, 0.0) + + @staticmethod + def reduce(buffer, value): + if value is not None: + return (buffer[0] + 1, buffer[1] + value, buffer[2] + value * value) + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return ( + buffer1[0] + buffer2[0], + buffer1[1] + buffer2[1], + buffer1[2] + buffer2[2], + ) + + @staticmethod + def finish(reduction): + count, total, sq_total = reduction + if count < 2: + return None + mean = total / count + variance = (sq_total / count) - (mean * mean) + return variance**0.5 + + df = 
self.spark.createDataFrame( + [(2.0,), (4.0,), (4.0,), (4.0,), (5.0,), (5.0,), (7.0,), (9.0,)], ["value"] + ) + stddev_udaf = udaf(StdDevAgg(), "double", "StdDev") + result = df.agg(stddev_udaf(df.value)) + collected = result.collect()[0][0] + self.assertAlmostEqual(collected, 2.0, places=5) + + # Mode (uses Counter buffer) + class ModeAgg(Aggregator): + @staticmethod + def zero(): + return Counter() + + @staticmethod + def reduce(buffer, value): + if value is not None: + buffer = Counter(buffer) + buffer[value] += 1 + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return Counter(buffer1) + Counter(buffer2) + + @staticmethod + def finish(reduction): + return reduction.most_common(1)[0][0] if reduction else None + + df_int = self.spark.createDataFrame([(1,), (2,), (2,), (3,), (3,), (3,)], ["value"]) + mode_udaf = udaf(ModeAgg(), "bigint", "Mode") + result = df_int.agg(mode_udaf(df_int.value)) + assertDataFrameEqual(result, [Row(**{"Mode(value)": 3})]) + + # Return Array type (Top-3) + class Top3Agg(Aggregator): + @staticmethod + def zero(): + return [] + + @staticmethod + def reduce(buffer, value): + if value is not None: + buffer = list(buffer) + buffer.append(value) + buffer.sort(reverse=True) + return buffer[:3] + return buffer + + @staticmethod + def merge(buffer1, buffer2): + combined = list(buffer1) + list(buffer2) + combined.sort(reverse=True) + return combined[:3] + + @staticmethod + def finish(reduction): + return reduction + + df = self.spark.createDataFrame([(5,), (1,), (9,), (3,), (7,), (2,)], ["value"]) + top3_udaf = udaf(Top3Agg(), ArrayType(LongType()), "Top3") + result = df.agg(top3_udaf(df.value)) + collected = result.collect()[0][0] + self.assertEqual(sorted(collected, reverse=True), [9, 7, 5]) + + # ============ Validation Tests ============ + + def test_udaf_creation_and_interface(self): + """Test UDAF creation, interface validation, and column attributes.""" + # Interface + self.assertTrue(hasattr(Aggregator, "zero")) + self.assertTrue(hasattr(Aggregator, "reduce")) + self.assertTrue(hasattr(Aggregator, "merge")) + self.assertTrue(hasattr(Aggregator, "finish")) + + # Creation + class SumAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (value or 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + sum_udaf = udaf(SumAgg(), "bigint", "MySum") + self.assertIsInstance(sum_udaf, UserDefinedAggregateFunction) + self.assertEqual(sum_udaf._name, "MySum") + + # Column attributes + df = self.spark.createDataFrame([(1,)], ["value"]) + col = sum_udaf(df.value) + self.assertTrue(hasattr(col, "_udaf_func")) + self.assertTrue(hasattr(col, "_udaf_col")) + + def test_udaf_invalid_inputs(self): + """Test error handling for invalid inputs.""" + + # Invalid aggregator (missing required methods) + class MissingMethods: + @staticmethod + def zero(): + return 0 + + # Missing reduce, merge, finish + + with self.assertRaises(PySparkTypeError): + udaf(MissingMethods(), "bigint") + + # Non-static method raises error + class NonStaticZero(Aggregator): + def zero(self): # Missing @staticmethod + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + + @staticmethod + def finish(reduction): + return reduction + + with self.assertRaises(PySparkTypeError) as ctx: + udaf(NonStaticZero(), "bigint") + self.assertIn("NOT_CALLABLE", 
str(ctx.exception)) + + def test_udaf_unsupported_operations(self): + """Test unsupported operations raise appropriate errors.""" + + class SumAgg(Aggregator): + @staticmethod + def zero(): + return 0 + + @staticmethod + def reduce(buffer, value): + return buffer + (value or 0) + + @staticmethod + def merge(buffer1, buffer2): + return buffer1 + buffer2 + + @staticmethod + def finish(reduction): + return reduction + + df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) + sum_udaf = udaf(SumAgg(), "bigint", "MySum") + + # Multiple UDAFs not supported + with self.assertRaises(PySparkNotImplementedError): + df.agg(sum_udaf(df.a), sum_udaf(df.b)) + + # Mixed UDAF with other agg not supported + from pyspark.sql.functions import min as spark_min + + with self.assertRaises(PySparkNotImplementedError): + df.agg(sum_udaf(df.a), spark_min(df.b)) + + +class UDAFTests(UDAFTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.testing.utils import have_pandas, have_pyarrow + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udaf.py b/python/pyspark/sql/udaf.py new file mode 100644 index 000000000000..cb210275eeb1 --- /dev/null +++ b/python/pyspark/sql/udaf.py @@ -0,0 +1,897 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +User-defined aggregate function related classes and functions +""" +from typing import Any, TYPE_CHECKING, Optional, List, Iterator, Tuple + +from pyspark.sql.column import Column +from pyspark.sql.types import ( + DataType, + _parse_datatype_string, +) +from pyspark.errors import PySparkTypeError, PySparkNotImplementedError + +if TYPE_CHECKING: + from pyspark.sql._typing import DataTypeOrString, ColumnOrName + from pyspark.sql.dataframe import DataFrame + +__all__ = [ + "Aggregator", + "UserDefinedAggregateFunction", + "udaf", +] + + +class Aggregator: + """ + Base class for user-defined aggregations. + + This class defines the interface for implementing user-defined aggregate functions (UDAFs) + in Python. Users should subclass this class and implement the required methods. + + All methods (zero, reduce, merge, finish) MUST be decorated with @staticmethod. + This ensures the aggregator can be properly serialized and executed across Spark workers. + + .. versionadded:: 4.2.0 + + Examples + -------- + >>> class MySum(Aggregator): + ... @staticmethod + ... def zero(): + ... return 0 + ... @staticmethod + ... def reduce(buffer, value): + ... return buffer + value + ... @staticmethod + ... def merge(buffer1, buffer2): + ... return buffer1 + buffer2 + ... @staticmethod + ... def finish(reduction): + ... 
return reduction + """ + + @staticmethod + def zero() -> Any: + """ + A zero value for this aggregation. Should satisfy the property that any b + zero = b. + + Must be decorated with @staticmethod. + + Returns + ------- + Any + The zero value for the aggregation buffer. + """ + raise NotImplementedError + + @staticmethod + def reduce(buffer: Any, value: Any) -> Any: + """ + Combine an input value into the current intermediate value. + + For performance, the function may modify `buffer` and return it instead of + constructing a new object. + + Must be decorated with @staticmethod. + + Parameters + ---------- + buffer : Any + The current intermediate value (buffer). + value : Any + The input value to aggregate. + + Returns + ------- + Any + The updated buffer. + """ + raise NotImplementedError + + @staticmethod + def merge(buffer1: Any, buffer2: Any) -> Any: + """ + Merge two intermediate values. + + Must be decorated with @staticmethod. + + Parameters + ---------- + buffer1 : Any + The first intermediate value. + buffer2 : Any + The second intermediate value. + + Returns + ------- + Any + The merged intermediate value. + """ + raise NotImplementedError + + @staticmethod + def finish(reduction: Any) -> Any: + """ + Transform the output of the reduction. + + Must be decorated with @staticmethod. + + Parameters + ---------- + reduction : Any + The final reduction result. + + Returns + ------- + Any + The final output value. + """ + raise NotImplementedError + + +def _validate_aggregator_methods(aggregator: Aggregator) -> None: + """ + Validate that all required Aggregator methods are decorated with @staticmethod. + + Parameters + ---------- + aggregator : Aggregator + The aggregator instance to validate. + + Raises + ------ + PySparkTypeError + If any required method is not a static method. + """ + required_methods = ["zero", "reduce", "merge", "finish"] + aggregator_class = type(aggregator) + + for method_name in required_methods: + # Check if the method exists on the class (not just inherited from Aggregator base) + if not hasattr(aggregator_class, method_name): + raise PySparkTypeError( + errorClass="NOT_CALLABLE", + messageParameters={ + "arg_name": f"aggregator.{method_name}", + "arg_type": "missing", + }, + ) + + # Get the method from the class definition (not the instance) + class_attr = getattr(aggregator_class, method_name) + + # Check if it's a staticmethod by looking at the class __dict__ + # (methods bound to instances lose their staticmethod wrapper) + if method_name in aggregator_class.__dict__: + raw_method = aggregator_class.__dict__[method_name] + if not isinstance(raw_method, staticmethod): + raise PySparkTypeError( + errorClass="NOT_CALLABLE", + messageParameters={ + "arg_name": f"aggregator.{method_name}", + "arg_type": f"non-static method (must use @staticmethod decorator)", + }, + ) + + +class UserDefinedAggregateFunction: + """ + User-defined aggregate function wrapper for Python Aggregator. + + This class wraps an Aggregator instance and provides the functionality to use it + as an aggregate function in Spark SQL. The implementation uses mapInArrow and + applyInArrow to perform partial aggregation and final aggregation. + + .. 
versionadded:: 4.2.0 + """ + + def __init__( + self, + aggregator: Aggregator, + returnType: "DataTypeOrString", + name: Optional[str] = None, + ): + if not isinstance(aggregator, Aggregator): + raise PySparkTypeError( + errorClass="NOT_CALLABLE", + messageParameters={ + "arg_name": "aggregator", + "arg_type": type(aggregator).__name__, + }, + ) + + if not isinstance(returnType, (DataType, str)): + raise PySparkTypeError( + errorClass="NOT_DATATYPE_OR_STR", + messageParameters={ + "arg_name": "returnType", + "arg_type": type(returnType).__name__, + }, + ) + + # Validate that all required methods are static methods + _validate_aggregator_methods(aggregator) + + self.aggregator = aggregator + self._returnType = returnType + self._name = name or ( + aggregator.__class__.__name__ + if hasattr(aggregator, "__class__") + else "UserDefinedAggregateFunction" + ) + # Serialize aggregator for use in Arrow functions + # Use cloudpickle to ensure proper serialization of classes + try: + import cloudpickle + except ImportError: + import pickle as cloudpickle + self._serialized_aggregator = cloudpickle.dumps(aggregator) + + @property + def returnType(self) -> DataType: + """Get the return type of this UDAF.""" + if isinstance(self._returnType, DataType): + return self._returnType + else: + return _parse_datatype_string(self._returnType) + + def __call__(self, *args: "ColumnOrName") -> Column: + """ + Apply this UDAF to the given columns. + + This creates a Column expression that can be used in DataFrame operations. + The actual aggregation is performed using mapInArrow and applyInArrow. + + Parameters + ---------- + *args : ColumnOrName + The columns to aggregate. Currently supports a single column. + + Returns + ------- + Column + A Column representing the aggregation result. + + Notes + ----- + This implementation uses mapInArrow and applyInArrow internally to perform + the aggregation. The approach follows: + 1. mapInArrow: Performs partial aggregation (reduce) on each partition + 2. groupBy: Groups partial results by a random key (range based on + spark.sql.shuffle.partitions config or DataFrame partition count) + 3. applyInArrow: Merges partial results and produces final result + + Examples + -------- + >>> class MySum(Aggregator): + ... @staticmethod + ... def zero(): + ... return 0 + ... @staticmethod + ... def reduce(buffer, value): + ... return buffer + value + ... @staticmethod + ... def merge(buffer1, buffer2): + ... return buffer1 + buffer2 + ... @staticmethod + ... def finish(reduction): + ... return reduction + ... 
+ >>> sum_udaf = udaf(MySum(), "bigint") + >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["value"]) + >>> df.agg(sum_udaf(df.value)).show() + +------------+ + |MySum(value)| + +------------+ + | 6| + +------------+ + """ + # Return a Column with UDAF metadata attached as an attribute + # This allows GroupedData.agg() to detect and handle UDAF columns + # without introducing a special Column type + from pyspark.sql.classic.column import Column as ClassicColumn + from pyspark.sql.functions import col as spark_col + + col_expr = args[0] + if isinstance(col_expr, str): + col_expr = spark_col(col_expr) + + # Create a Column and attach UDAF information as an attribute + # This is similar to how pandas UDF works - the Column contains metadata + # that can be checked in agg() without requiring a special Column type + result_col = ClassicColumn(col_expr._jc) # type: ignore[attr-defined] + # Attach UDAF metadata as attributes (not a special type) + result_col._udaf_func = self # type: ignore[attr-defined] + result_col._udaf_col = col_expr # type: ignore[attr-defined] + return result_col + + +def _extract_column_name(col_expr: "ColumnOrName") -> tuple[Column, str]: + """Extract column name from Column or string, return (Column, name).""" + from pyspark.sql.functions import col as spark_col + + if isinstance(col_expr, str): + return spark_col(col_expr), col_expr + else: + # Extract column name from expression string (e.g., "value" from "Column<'value'>") + col_name_str = col_expr._jc.toString() if hasattr(col_expr, "_jc") else str(col_expr) + col_name = col_name_str.split("'")[1] if "'" in col_name_str else "value" + return col_expr, col_name + + +def _extract_grouping_column_names(grouping_cols: List[Column]) -> List[str]: + """Extract grouping column names from Column objects.""" + grouping_col_names = [] + for gc in grouping_cols: + gc_str = gc._jc.toString() if hasattr(gc, "_jc") else str(gc) + if "'" in gc_str: + gc_name = gc_str.split("'")[1] + else: + # Fallback: use the string representation + gc_name = gc_str.split("(")[0].strip() if "(" in gc_str else gc_str.strip() + grouping_col_names.append(gc_name) + return grouping_col_names + + +def _extract_grouping_columns_from_jvm(jgd: Any) -> List[Column]: + """ + Extract grouping columns from GroupedData's JVM representation. + + Parameters + ---------- + jgd : JavaObject + The JVM GroupedData object. + + Returns + ------- + List[Column] + List of grouping column expressions, empty if no grouping or parsing fails. + """ + from pyspark.sql.functions import col as spark_col + import re + + try: + jvm_string = jgd.toString() + # Format: "RelationalGroupedDataset: [grouping expressions: [col1, col2], ...]" + match = re.search(r"grouping expressions:\s*\[([^\]]+)\]", jvm_string) + if match: + grouping_exprs_str = match.group(1) + grouping_col_names = [name.strip() for name in grouping_exprs_str.split(",")] + return [spark_col(name.strip()) for name in grouping_col_names] + except Exception: + # If parsing fails, assume no grouping + pass + return [] + + +def _apply_udaf_via_catalyst( + df: "DataFrame", + jgd: Any, + udaf_func: "UserDefinedAggregateFunction", + udaf_col_expr: Column, + grouping_cols: Optional[List[Column]], +) -> "DataFrame": + """ + Apply UDAF via the Catalyst optimizer path (Scala). + + This creates three Arrow UDFs and passes them to the Scala + pythonAggregatorUDAF method, which uses the Catalyst optimizer + to execute the aggregation. + + Parameters + ---------- + df : DataFrame + The original DataFrame. 
+ jgd : JavaObject + The JVM GroupedData object. + udaf_func : UserDefinedAggregateFunction + The UDAF function. + udaf_col_expr : Column + The column expression to aggregate. + grouping_cols : Optional[List[Column]] + The grouping columns if this is a grouped aggregation. + + Returns + ------- + DataFrame + Aggregated result DataFrame + """ + from pyspark.sql import DataFrame + from pyspark.sql.pandas.functions import pandas_udf + from pyspark.util import PythonEvalType + from pyspark.sql.types import StructType, StructField, LongType, BinaryType + + # Get aggregator info + serialized_aggregator = udaf_func._serialized_aggregator + return_type = udaf_func.returnType + has_grouping = grouping_cols is not None and len(grouping_cols) > 0 + grouping_col_names = _extract_grouping_column_names(grouping_cols) if has_grouping else [] + col_expr, col_name = _extract_column_name(udaf_col_expr) + + # Get max key for random grouping + max_key = _get_max_key_for_random_grouping(df) + + # Create the three phase functions + reduce_func = _create_reduce_func(serialized_aggregator, max_key, len(grouping_col_names)) + merge_func = _create_merge_func(serialized_aggregator) + + return_type_str = ( + return_type.simpleString() if hasattr(return_type, "simpleString") else str(return_type) + ) + result_col_name_safe = f"{udaf_func._name}_{col_name}".replace("(", "_").replace(")", "_") + final_merge_func = _create_final_merge_func( + serialized_aggregator, + return_type, + has_grouping, + grouping_col_names, + udaf_func._name, + col_name, + ) + + # Create Arrow UDFs for each phase + reduce_schema = StructType( + [ + StructField("key", LongType(), False), + StructField("buffer", BinaryType(), True), + ] + ) + reduce_udf = pandas_udf( + reduce_func, + returnType=reduce_schema, + functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + ) + + merge_schema = StructType([StructField("buffer", BinaryType(), True)]) + merge_udf = pandas_udf( + merge_func, + returnType=merge_schema, + functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, + ) + + final_schema_str = _build_result_schema( + has_grouping, grouping_col_names, result_col_name_safe, return_type_str + ) + final_udf = pandas_udf( + final_merge_func, + returnType=final_schema_str, + functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF, + ) + + # Apply UDFs to the columns in correct order: grouping columns first, then value column + # This ensures reduce_func receives columns in the expected order: + # [group_col_0, group_col_1, ..., group_col_n, value_col] + ordered_cols = [df[c] for c in grouping_col_names] + [col_expr] + reduce_udf_col = reduce_udf(*ordered_cols) + merge_udf_col = merge_udf(*ordered_cols) + final_udf_col = final_udf(*ordered_cols) + + # Get result type as JSON string + spark_session = df.sparkSession + result_type_json = return_type.json() + + # Call the Scala pythonAggregatorUDAF method + jdf = jgd.pythonAggregatorUDAF( + reduce_udf_col._jc, + merge_udf_col._jc, + final_udf_col._jc, + result_type_json, + ) + + result_df = DataFrame(jdf, spark_session) + + # Rename result column to match expected format + from pyspark.sql.functions import col as spark_col + + result_col_name = f"{udaf_func._name}({col_name})" + if has_grouping: + select_exprs = [spark_col(gc_name) for gc_name in grouping_col_names] + select_exprs.append(spark_col("result").alias(result_col_name)) + return result_df.select(*select_exprs) + else: + return result_df.select(spark_col("result").alias(result_col_name)) + + +def _handle_udaf_aggregation_in_grouped_data( + df: 
"DataFrame", + jgd: Any, + exprs: Tuple[Column, ...], + udaf_cols: List[Column], +) -> "DataFrame": + """ + Handle UDAF aggregation in GroupedData.agg() method. + + Parameters + ---------- + df : DataFrame + The original DataFrame. + jgd : JavaObject + The JVM GroupedData object. + exprs : Tuple[Column, ...] + All expression columns passed to agg() + udaf_cols : List[Column] + Columns that have _udaf_func attribute + + Returns + ------- + DataFrame + Aggregated result DataFrame + """ + # Validate UDAF usage constraints + if len(udaf_cols) > 1: + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={ + "feature": "Multiple UDAFs in a single agg() call. Currently only one UDAF is supported." + }, + ) + if len(exprs) > 1: + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={ + "feature": "Mixing UDAF with other aggregate functions. Currently only single UDAF is supported." + }, + ) + + # Extract UDAF information + udaf_col = udaf_cols[0] + udaf_func = udaf_col._udaf_func # type: ignore[attr-defined] + udaf_col_expr = udaf_col._udaf_col # type: ignore[attr-defined] + + # Get grouping columns + grouping_cols = _extract_grouping_columns_from_jvm(jgd) + + # Use Catalyst optimizer path (via Scala pythonAggregatorUDAF) + return _apply_udaf_via_catalyst( + df, + jgd, + udaf_func, + udaf_col_expr, + grouping_cols if grouping_cols else None, + ) + + +def _get_max_key_for_random_grouping(df: "DataFrame") -> int: + """Get max key for random grouping based on Spark config or partition count.""" + try: + spark_session = df.sparkSession + shuffle_partitions = int(spark_session.conf.get("spark.sql.shuffle.partitions", "200")) + num_partitions = df.rdd.getNumPartitions() + return max(shuffle_partitions, num_partitions, 1) + except Exception: + return 200 + + +def _convert_results_to_arrow(results: List[Any], return_type: DataType) -> Any: + """Convert a list of result values to Arrow array based on return type.""" + import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_type + from pyspark.sql.conversion import LocalDataToArrowConversion + + # Use existing conversion utilities for accurate type handling + arrow_type = to_arrow_type(return_type) + converter = LocalDataToArrowConversion._create_converter(return_type, nullable=True) + + if converter is not None: + converted_results = [converter(r) for r in results] + else: + converted_results = results + + return pa.array(converted_results, type=arrow_type) + + +def _create_reduce_func( + serialized_aggregator: bytes, + max_key: int, + num_grouping_cols: int, +): + """Create reduce function for mapInArrow.""" + + def reduce_func(iterator): + import pyarrow as pa + import cloudpickle + import random + + agg = cloudpickle.loads(serialized_aggregator) + group_buffers = {} + value_col_idx = num_grouping_cols + + for batch in iterator: + if ( + isinstance(batch, pa.RecordBatch) + and batch.num_columns > value_col_idx + and batch.num_rows > 0 + ): + value_col = batch.column(value_col_idx) + + for row_idx in range(batch.num_rows): + # Extract grouping key (None for non-grouped case, tuple for grouped case) + grouping_key = ( + tuple([batch.column(i)[row_idx].as_py() for i in range(num_grouping_cols)]) + if num_grouping_cols > 0 + else None + ) + + value = value_col[row_idx].as_py() + + if grouping_key not in group_buffers: + group_buffers[grouping_key] = agg.zero() + + if value is not None: + group_buffers[grouping_key] = agg.reduce(group_buffers[grouping_key], value) + + # Handle empty 
DataFrame case for non-grouped aggregation + if not group_buffers and num_grouping_cols == 0: + group_buffers[None] = agg.zero() + + # Yield one record per group with random key (always serialize as (grouping_key, buffer)) + for grouping_key, buffer in group_buffers.items(): + key = random.randint(0, max_key) + grouping_key_bytes = cloudpickle.dumps((grouping_key, buffer)) + yield pa.RecordBatch.from_arrays( + [ + pa.array([key], type=pa.int64()), + pa.array([grouping_key_bytes], type=pa.binary()), + ], + ["key", "buffer"], + ) + + return reduce_func + + +def _create_merge_func(serialized_aggregator: bytes): + """Create merge function for applyInArrow using iterator API.""" + import pyarrow as pa + + def merge_func(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: + """Iterator-based merge function that processes batches one by one. + + Note: When called via FlatMapGroupsInArrow, batches contain only the value columns + (key is handled separately by the executor). The buffer column is at index 0. + When called via Python-only path (applyInArrow), batches contain all columns + (key at 0, buffer at 1). + """ + import cloudpickle + + agg = cloudpickle.loads(serialized_aggregator) + group_buffers = {} + + for batch in batches: + if isinstance(batch, pa.RecordBatch) and batch.num_columns > 0 and batch.num_rows > 0: + # Buffer is at last column position (handles both 1-column and 2-column cases) + buffer_col = batch.column(batch.num_columns - 1) + for i in range(batch.num_rows): + buffer_bytes = buffer_col[i].as_py() + grouping_key, buffer_value = cloudpickle.loads(buffer_bytes) + + if grouping_key not in group_buffers: + group_buffers[grouping_key] = buffer_value + else: + group_buffers[grouping_key] = agg.merge( + group_buffers[grouping_key], buffer_value + ) + + # Yield merged buffers (always serialize as (grouping_key, buffer)) + for grouping_key, buffer in group_buffers.items(): + grouping_key_bytes = cloudpickle.dumps((grouping_key, buffer)) + yield pa.RecordBatch.from_arrays( + [pa.array([grouping_key_bytes], type=pa.binary())], ["buffer"] + ) + + return merge_func + + +def _create_final_merge_func( + serialized_aggregator: bytes, + return_type: DataType, + has_grouping: bool, + grouping_col_names: List[str], + udaf_func_name: str, + col_name: str, +): + """Create final merge function for applyInArrow using iterator API.""" + import pyarrow as pa + import cloudpickle + + # Serialize return_type for use in worker + serialized_return_type = cloudpickle.dumps(return_type) + + def _convert_results_to_arrow_local(results: List[Any], serialized_dt: bytes) -> Any: + """Convert a list of result values to Arrow array based on return type.""" + from pyspark.sql.pandas.types import to_arrow_type + from pyspark.sql.conversion import LocalDataToArrowConversion + + # Use DataType object for accurate conversion + dt = cloudpickle.loads(serialized_dt) + arrow_type = to_arrow_type(dt) + converter = LocalDataToArrowConversion._create_converter(dt, nullable=True) + + if converter is not None: + converted_results = [converter(r) for r in results] + else: + converted_results = results + + return pa.array(converted_results, type=arrow_type) + + def final_merge_func(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: + """Iterator-based final merge function that processes batches one by one. + + Note: When called via FlatMapGroupsInArrow, batches contain only the value columns + (key is handled separately by the executor). The buffer column is at index 0. 
+ When called via Python-only path (applyInArrow), batches contain all columns + (final_key at 0, buffer at 1). + """ + import cloudpickle + + agg = cloudpickle.loads(serialized_aggregator) + + # Unified logic: always deserialize as (grouping_key, buffer) + # For non-grouped case, grouping_key is None + group_results = {} + has_data = False + + for batch in batches: + if isinstance(batch, pa.RecordBatch) and batch.num_columns > 0: + if batch.num_rows > 0: + has_data = True + # Buffer is at last column position (handles both 1-column and 2-column cases) + buffer_col = batch.column(batch.num_columns - 1) + for i in range(batch.num_rows): + buffer_bytes = buffer_col[i].as_py() + # Always deserialize as (grouping_key, buffer) + grouping_key, buffer_value = cloudpickle.loads(buffer_bytes) + + if grouping_key not in group_results: + group_results[grouping_key] = buffer_value + else: + group_results[grouping_key] = agg.merge( + group_results[grouping_key], buffer_value + ) + + # Finish each group and collect all results + all_grouping_vals = [] + all_results = [] + + for grouping_key, buffer in group_results.items(): + all_grouping_vals.append(grouping_key) + all_results.append(agg.finish(buffer)) + + # Handle empty case: for non-grouped, use zero() result; for grouped, keep empty + # Unified handling: both cases go through the same array building logic + if not has_data and not has_grouping: + # Non-grouped empty case: return zero() result + all_results = [agg.finish(agg.zero())] + + # Build result arrays - unified for both grouped and non-grouped cases + result_arrays = [] + result_names = [] + + # Add grouping columns (empty list for non-grouped case, so loop won't execute) + for i in range(len(grouping_col_names)): + col_values = [grouping_vals[i] for grouping_vals in all_grouping_vals] + if col_values and isinstance(col_values[0], str): + result_arrays.append(pa.array(col_values, type=pa.string())) + else: + result_arrays.append(pa.array(col_values, type=pa.int64())) + result_names.append(grouping_col_names[i]) + + # Add result column (always add, even if empty for grouped case) + result_col_name_safe = f"{udaf_func_name}_{col_name}".replace("(", "_").replace(")", "_") + result_arrays.append(_convert_results_to_arrow_local(all_results, serialized_return_type)) + result_names.append(result_col_name_safe if has_grouping else "result") + + yield pa.RecordBatch.from_arrays(result_arrays, result_names) + + return final_merge_func + + +def _build_result_schema( + has_grouping: bool, + grouping_col_names: List[str], + result_col_name_safe: str, + return_type_str: str, +) -> str: + """Build schema string for final merge result.""" + if has_grouping: + schema_parts = [f"{gc_name} string" for gc_name in grouping_col_names] + schema_parts.append(f"{result_col_name_safe} {return_type_str}") + return ", ".join(schema_parts) + else: + return f"result {return_type_str}" + + +def udaf( + aggregator: Aggregator, + returnType: "DataTypeOrString", + name: Optional[str] = None, +) -> UserDefinedAggregateFunction: + """ + Creates a user-defined aggregate function (UDAF) from an Aggregator instance. + + .. versionadded:: 4.2.0 + + Parameters + ---------- + aggregator : Aggregator + An instance of Aggregator that implements the aggregation logic. + returnType : :class:`pyspark.sql.types.DataType` or str + The return type of the UDAF. Can be either a DataType object or a DDL-formatted string. + name : str, optional + Optional name for the UDAF. If not provided, uses the aggregator class name. 
+ + Returns + ------- + UserDefinedAggregateFunction + A UserDefinedAggregateFunction that can be used in DataFrame operations. + + Examples + -------- + >>> class MySum(Aggregator): + ... @staticmethod + ... def zero(): + ... return 0 + ... @staticmethod + ... def reduce(buffer, value): + ... return buffer + value + ... @staticmethod + ... def merge(buffer1, buffer2): + ... return buffer1 + buffer2 + ... @staticmethod + ... def finish(reduction): + ... return reduction + ... + >>> sum_udaf = udaf(MySum(), "bigint") + >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["value"]) + >>> df.agg(sum_udaf(df.value)).show() + +------------+ + |MySum(value)| + +------------+ + | 6| + +------------+ + """ + return UserDefinedAggregateFunction(aggregator, returnType, name) + + +def _test() -> None: + import doctest + import sys + from pyspark.sql import SparkSession + from pyspark.testing.utils import have_pandas, have_pyarrow + import pyspark.sql.udaf + + globs = pyspark.sql.udaf.__dict__.copy() + + if not have_pandas or not have_pyarrow: + del pyspark.sql.udaf.udaf.__doc__ + del pyspark.sql.udaf.UserDefinedAggregateFunction.__call__.__doc__ + + spark = SparkSession.builder.master("local[4]").appName("sql.udaf tests").getOrCreate() + globs["spark"] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udaf, + globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF, + ) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a2dced57c715..3b9a50207eb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -331,6 +331,7 @@ abstract class Optimizer(catalogManager: CatalogManager) ReplaceCurrentLike(catalogManager), SpecialDatetimeValues, RewriteAsOfJoin, + RewritePythonAggregatorUDAF, EvalInlineTables, ReplaceTranspose, RewriteCollationJoin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAF.scala new file mode 100644 index 000000000000..59897ee7c1e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAF.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types._ + +/** + * Rewrites [[PythonAggregatorUDAF]] logical operator using a combination of + * MapInArrow, Aggregate, and FlatMapGroupsInArrow operators. + * + * This implements a three-phase aggregation pattern: + * 1. Partial aggregation (MapInArrow): Applies reduce() on each partition, outputs + * (random_key, buffer) pairs + * 2. Intermediate merge (FlatMapGroupsInArrow): Groups by random key, applies merge() + * 3. Final merge (FlatMapGroupsInArrow): Groups by actual group keys (or single key), + * applies merge() + finish() + * + * The key insight is that we must create NEW PythonUDF expressions for each phase, + * using the correct intermediate attribute references. We extract the PythonFunction + * from the original UDF expressions and create new PythonUDFs with the correct children. + */ +object RewritePythonAggregatorUDAF extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput { + case udaf @ PythonAggregatorUDAF( + groupingAttributes, + partialReduceUDF, + mergeUDF, + finalMergeUDF, + resultAttribute, + child) => + + // Extract PythonFunction from each UDF expression + val reduceFunc = extractPythonFunction(partialReduceUDF) + val mergeFunc = extractPythonFunction(mergeUDF) + val finalFunc = extractPythonFunction(finalMergeUDF) + + // Step 1: MapInArrow for partial aggregation + // The reduce UDF takes the child's output and produces (key, buffer) + val mapInArrowOutput = Seq( + AttributeReference("key", LongType, nullable = false)(), + AttributeReference("buffer", BinaryType, nullable = true)() + ) + + // Use the original UDF's children (columns specified by Python in correct order) + // This preserves the ordering: [grouping_cols..., value_col] + val originalReduceChildren = partialReduceUDF.asInstanceOf[PythonUDF].children + val reduceUDFWithCorrectChildren = createPythonUDF( + reduceFunc, + partialReduceUDF, + originalReduceChildren, // Preserve original column ordering from Python + StructType(mapInArrowOutput.map(a => StructField(a.name, a.dataType, a.nullable))) + ) + + val mapInArrow = MapInArrow( + reduceUDFWithCorrectChildren, + mapInArrowOutput, + child, + isBarrier = false, + profile = None + ) + + // Get actual output attributes from mapInArrow + val keyAttr = mapInArrow.output.head + val bufferAttr = mapInArrow.output(1) + + // Step 2: Group by random key and apply merge via FlatMapGroupsInArrow + val mergeOutputAttrs = Seq(AttributeReference("buffer", BinaryType, nullable = true)()) + + // Create merge UDF that takes (key, buffer) and produces (buffer) + val mergeUDFWithCorrectChildren = createPythonUDF( + mergeFunc, + mergeUDF, + Seq(keyAttr, bufferAttr), // Reference the mapInArrow output + StructType(mergeOutputAttrs.map(a => StructField(a.name, a.dataType, a.nullable))) + ) + + val flatMapMerge = FlatMapGroupsInArrow( + Seq(keyAttr), + mergeUDFWithCorrectChildren, + mergeOutputAttrs, + mapInArrow + ) + val mergedBufferAttr = flatMapMerge.output.head + + // Step 3: Add a constant key for final grouping + val finalKeyAlias = Alias(Literal(0L), "final_key")() + val projectWithFinalKey = Project( + Seq(finalKeyAlias, mergedBufferAttr), + flatMapMerge + ) + val finalKeyAttr = projectWithFinalKey.output.head + val finalBufferAttr 
= projectWithFinalKey.output(1) + + // Step 4: Group by final key and apply final merge + finish + val finalOutput = if (groupingAttributes.nonEmpty) { + groupingAttributes.map(_.toAttribute) :+ resultAttribute + } else { + Seq(resultAttribute) + } + + // Create final UDF that takes (final_key, buffer) and produces the result + val finalUDFWithCorrectChildren = createPythonUDF( + finalFunc, + finalMergeUDF, + Seq(finalKeyAttr, finalBufferAttr), // Reference the project output + StructType(finalOutput.map(a => StructField(a.name, a.dataType, a.nullable))) + ) + + val flatMapFinal = FlatMapGroupsInArrow( + Seq(finalKeyAttr), + finalUDFWithCorrectChildren, + finalOutput, + projectWithFinalKey + ) + + // Step 5: Project out the final key if it was just for grouping + val result = if (groupingAttributes.isEmpty) { + Project(Seq(resultAttribute), flatMapFinal) + } else { + Project(finalOutput, flatMapFinal) + } + + val attrMapping = udaf.output.zip(result.output) + result -> attrMapping + } + + /** + * Extract PythonFunction from a PythonUDF expression. + */ + private def extractPythonFunction(expr: Expression): PythonFunction = { + expr match { + case udf: PythonUDF => udf.func + case other => + throw new IllegalArgumentException( + s"Expected PythonUDF but got ${other.getClass.getSimpleName}") + } + } + + /** + * Create a new PythonUDF with the given function but different children (attribute references). + */ + private def createPythonUDF( + func: PythonFunction, + originalUDF: Expression, + newChildren: Seq[Expression], + returnType: DataType): PythonUDF = { + originalUDF match { + case udf: PythonUDF => + PythonUDF( + name = udf.name, + func = func, + dataType = returnType, + children = newChildren, + evalType = udf.evalType, + udfDeterministic = udf.udfDeterministic + ) + case other => + throw new IllegalArgumentException( + s"Expected PythonUDF but got ${other.getClass.getSimpleName}") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index bcfcae2ee16c..d3bf9547806a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -447,3 +447,37 @@ object PythonWorkerLogs extends SQLConfHelper { TableFunctionName -> (info, funcBuilder) } } + +/** + * Aggregation using a Python Aggregator class with zero, reduce, merge, finish methods. + * + * This logical plan represents a user-defined aggregate function (UDAF) implemented in Python + * using the Aggregator pattern. The Aggregator is serialized and sent to Python workers. + * + * The execution pattern follows a three-phase aggregation: + * 1. Partial aggregation (MapInArrow): Each partition applies reduce() to local data + * 2. Intermediate merge (FlatMapGroupsInArrow): Merge partial results by random key + * 3. 
Final merge (FlatMapGroupsInArrow): Final merge by group keys + finish() + * + * @param groupingAttributes attributes used for grouping + * @param partialReduceUDF PythonUDF expression for partial aggregation (reduce phase) + * @param mergeUDF PythonUDF expression for intermediate merging + * @param finalMergeUDF PythonUDF expression for final merge + finish + * @param resultAttribute the output attribute with the aggregation result + * @param child the child logical plan + */ +case class PythonAggregatorUDAF( + groupingAttributes: Seq[Attribute], + partialReduceUDF: Expression, + mergeUDF: Expression, + finalMergeUDF: Expression, + resultAttribute: Attribute, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = groupingAttributes :+ resultAttribute + + override val producedAttributes: AttributeSet = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): PythonAggregatorUDAF = + copy(child = newChild) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAFSuite.scala new file mode 100644 index 000000000000..e89e8fcf2034 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewritePythonAggregatorUDAFSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +class RewritePythonAggregatorUDAFSuite extends PlanTest { + + test("rewrite ungrouped aggregation") { + val child = LocalRelation($"a".int, $"b".int) + + // Create mock UDF expressions (using Literal as placeholder) + val partialReduceUDF = Literal("partial_reduce_udf") + val mergeUDF = Literal("merge_udf") + val finalMergeUDF = Literal("final_merge_udf") + val resultAttr = AttributeReference("result", LongType)() + + val query = PythonAggregatorUDAF( + groupingAttributes = Seq.empty, + partialReduceUDF = partialReduceUDF, + mergeUDF = mergeUDF, + finalMergeUDF = finalMergeUDF, + resultAttribute = resultAttr, + child = child + ) + + val rewritten = RewritePythonAggregatorUDAF(query) + + // Verify that the rewritten plan does not contain PythonAggregatorUDAF + assert(!rewritten.exists(_.isInstanceOf[PythonAggregatorUDAF]), + "PythonAggregatorUDAF should be rewritten") + + // Verify the rewritten plan contains MapInArrow and FlatMapGroupsInArrow + assert(rewritten.exists(_.isInstanceOf[MapInArrow]), + "Rewritten plan should contain MapInArrow") + assert(rewritten.exists(_.isInstanceOf[FlatMapGroupsInArrow]), + "Rewritten plan should contain FlatMapGroupsInArrow") + } + + test("rewrite grouped aggregation") { + val child = LocalRelation($"a".int, $"b".int, $"c".int) + val groupingAttr = child.output.head + + // Create mock UDF expressions (using Literal as placeholder) + val partialReduceUDF = Literal("partial_reduce_udf") + val mergeUDF = Literal("merge_udf") + val finalMergeUDF = Literal("final_merge_udf") + val resultAttr = AttributeReference("result", LongType)() + + val query = PythonAggregatorUDAF( + groupingAttributes = Seq(groupingAttr), + partialReduceUDF = partialReduceUDF, + mergeUDF = mergeUDF, + finalMergeUDF = finalMergeUDF, + resultAttribute = resultAttr, + child = child + ) + + val rewritten = RewritePythonAggregatorUDAF(query) + + // Verify that the rewritten plan does not contain PythonAggregatorUDAF + assert(!rewritten.exists(_.isInstanceOf[PythonAggregatorUDAF]), + "PythonAggregatorUDAF should be rewritten") + + // Verify the rewritten plan contains MapInArrow and FlatMapGroupsInArrow + assert(rewritten.exists(_.isInstanceOf[MapInArrow]), + "Rewritten plan should contain MapInArrow") + assert(rewritten.exists(_.isInstanceOf[FlatMapGroupsInArrow]), + "Rewritten plan should contain FlatMapGroupsInArrow") + + // Verify the output includes the grouping attribute + val outputAttrs = rewritten.output.map(_.name) + assert(outputAttrs.contains("a") || outputAttrs.exists(_.contains("result")), + "Output should contain grouping attribute or result") + } + + test("non-PythonAggregatorUDAF plan unchanged") { + val child = LocalRelation($"a".int, $"b".int) + + val rewritten = RewritePythonAggregatorUDAF(child) + + // Should be unchanged (same object reference since no transformation needed) + assert(rewritten eq child, "Plan without PythonAggregatorUDAF should be unchanged") + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala index bd7b3348b9f0..52059a47c56e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.classic.TypedAggUtils.withInputType
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{IntegerType, NumericType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, NumericType, StructType}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -332,6 +332,59 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  /**
+   * Applies a Python Aggregator UDAF to the grouped data.
+   *
+   * This method is invoked from the Python side (pyspark.sql.udaf) when a UDAF column
+   * is aggregated. It builds a PythonAggregatorUDAF logical plan, which the
+   * RewritePythonAggregatorUDAF optimizer rule rewrites into MapInArrow and
+   * FlatMapGroupsInArrow operators.
+   *
+   * The Aggregator pattern uses three phases:
+   * 1. Partial reduce (mapInArrow): Each partition applies reduce() to local data
+   * 2. Intermediate merge (flatMapGroupsInArrow): Merge partial results by random key
+   * 3. Final merge (flatMapGroupsInArrow): Final merge by group keys + finish()
+   *
+   * @param partialReduceUDF UDF for the partial reduce phase
+   * @param mergeUDF UDF for the intermediate merge phase
+   * @param finalMergeUDF UDF for the final merge + finish phase
+   * @param resultTypeJson The return type of the UDAF as JSON string
+   */
+  private[sql] def pythonAggregatorUDAF(
+      partialReduceUDF: Column,
+      mergeUDF: Column,
+      finalMergeUDF: Column,
+      resultTypeJson: String): DataFrame = {
+
+    val resultType = DataType.fromJson(resultTypeJson)
+
+    val groupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    // Use the UDF's children directly - they already specify the correct columns in order
+    // The partialReduceUDF contains (grouping_cols..., value_col) as specified by Python
+    val udfChildren = partialReduceUDF.expr.asInstanceOf[PythonUDF].children
+    val child = df.logicalPlan
+    val project = df.sparkSession.sessionState.executePlan(
+      Project(udfChildren.map(_.asInstanceOf[NamedExpression]), child)).analyzed
+    val groupingAttributes = project.output.take(groupingNamedExpressions.length)
+
+    val resultAttribute = AttributeReference("result", resultType, nullable = true)()
+
+    val plan = PythonAggregatorUDAF(
+      groupingAttributes,
+      partialReduceUDF.expr,
+      mergeUDF.expr,
+      finalMergeUDF.expr,
+      resultAttribute,
+      project
+    )
+
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   /**
    * Applies a vectorized python user-defined function to each cogrouped data.
    * The user-defined function defines a transformation: