From 373b7b5e5630a34d2781f56ce2d4a684393870c0 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Wed, 26 Nov 2025 11:38:54 -0800
Subject: [PATCH 1/3] feat: add ArrowStreamAggPandasUDFSerializer

---
 python/pyspark/sql/pandas/serializers.py | 54 ++++++++++++++++++++++++
 python/pyspark/worker.py                 |  6 ++-
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 12a1f3c288b49..6327332118de6 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1185,6 +1185,60 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
+class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        int_to_decimal_coercion_enabled,
+    ):
+        super(ArrowStreamAggPandasUDFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            df_for_struct=False,
+            struct_in_pandas="dict",
+            ndarray_as_list=False,
+            arrow_cast=True,
+            input_types=None,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+
+    def load_stream(self, stream):
+        """
+        Deserialize Grouped ArrowRecordBatches and yield as a list of pandas.Series.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                yield (
+                    [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in enumerate(pa.Table.from_batches(ArrowStreamSerializer.load_stream(self, stream)).itercolumns())
+                    ]
+                )
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+                )
+
+    def __repr__(self):
+        return "ArrowStreamAggPandasUDFSerializer"
+
+
+# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
         self,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94e3b2728d08f..96aac6083bf27 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -67,6 +67,7 @@
     TransformWithStateInPySparkRowSerializer,
     TransformWithStateInPySparkRowInitStateSerializer,
     ArrowStreamArrowUDFSerializer,
+    ArrowStreamAggPandasUDFSerializer,
     ArrowStreamAggArrowUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
@@ -2721,10 +2722,13 @@ def read_udfs(pickleSer, infile, eval_type):
     ):
         ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
     elif eval_type in (
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
     ):
+        ser = ArrowStreamAggPandasUDFSerializer(
+            timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
+        )
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         ser = GroupPandasUDFSerializer(
             timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
         )

From e02f6b3275b53e7c617ede761a7a862c654e283b Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Wed, 26 Nov 2025 11:39:13 -0800
Subject: [PATCH 2/3] test: add test for window pandas agg multi udf case

---
 .../sql/tests/pandas/test_pandas_udf_window.py | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 6e1cbdaf73cff..7cfcb29f50c17 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -190,6 +190,32 @@ def test_multiple_udfs(self):
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
+    def test_multiple_udfs_in_single_projection(self):
+        """
+        Test multiple window aggregate pandas UDFs in a single select/projection.
+        """
+        df = self.data
+        w = self.unbounded_window
+
+        # Use select() with multiple window UDFs in the same projection
+        result1 = df.select(
+            df["id"],
+            df["v"],
+            self.pandas_agg_mean_udf(df["v"]).over(w).alias("mean_v"),
+            self.pandas_agg_max_udf(df["v"]).over(w).alias("max_v"),
+            self.pandas_agg_min_udf(df["w"]).over(w).alias("min_w"),
+        )
+
+        expected1 = df.select(
+            df["id"],
+            df["v"],
+            sf.mean(df["v"]).over(w).alias("mean_v"),
+            sf.max(df["v"]).over(w).alias("max_v"),
+            sf.min(df["w"]).over(w).alias("min_w"),
+        )
+
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+
     def test_replace_existing(self):
         df = self.data
         w = self.unbounded_window

From c2f1b659b66ce1896a5dae8b0853b3941fc50f7f Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Wed, 26 Nov 2025 11:39:28 -0800
Subject: [PATCH 3/3] fix: format

---
 python/pyspark/sql/pandas/serializers.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6327332118de6..f757ba4f696f8 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1224,7 +1224,11 @@ def load_stream(self, stream):
                 yield (
                     [
                         self.arrow_to_pandas(c, i)
-                        for i, c in enumerate(pa.Table.from_batches(ArrowStreamSerializer.load_stream(self, stream)).itercolumns())
+                        for i, c in enumerate(
+                            pa.Table.from_batches(
+                                ArrowStreamSerializer.load_stream(self, stream)
+                            ).itercolumns()
+                        )
                     ]
                 )
 
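
Not part of the patches above: a minimal usage sketch of the two evaluation paths that worker.py now routes through ArrowStreamAggPandasUDFSerializer, i.e. SQL_GROUPED_AGG_PANDAS_UDF (groupBy().agg()) and SQL_WINDOW_AGG_PANDAS_UDF (.over(window)). It assumes a local SparkSession with pandas and PyArrow available; the UDF name, column names, and sample data are illustrative only and are not taken from the patch series.

    import pandas as pd
    from pyspark.sql import SparkSession, Window
    from pyspark.sql.functions import pandas_udf

    spark = SparkSession.builder.getOrCreate()  # assumption: local session for this sketch

    # A grouped aggregate pandas UDF (Series -> scalar): each input column arrives
    # as a pandas.Series, which is what the new load_stream() yields per group.
    @pandas_udf("double")
    def mean_udf(v: pd.Series) -> float:
        return v.mean()

    # Illustrative data, not from the patches.
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ["id", "v"])

    # Grouped aggregation path (SQL_GROUPED_AGG_PANDAS_UDF).
    df.groupBy("id").agg(mean_udf(df["v"]).alias("mean_v")).show()

    # Window aggregation path (SQL_WINDOW_AGG_PANDAS_UDF), mirroring the unbounded
    # window used in test_multiple_udfs_in_single_projection above.
    w = Window.partitionBy("id").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    df.select(df["id"], df["v"], mean_udf(df["v"]).over(w).alias("mean_v")).show()

On the worker side, the serializer added in PATCH 1 (and only reformatted in PATCH 3) reads the per-group DataFrame count from the stream, assembles each group's Arrow record batches into a pyarrow.Table, and converts its columns to pandas.Series before the UDF is invoked.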