# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
    """
    Serializer for grouped/window aggregate pandas UDFs.

    Reads groups of Arrow record batches from the worker stream and converts
    each group into a list of pandas.Series (one per column).
    """

    def __init__(
        self,
        timezone,
        safecheck,
        assign_cols_by_name,
        int_to_decimal_coercion_enabled,
    ):
        super(ArrowStreamAggPandasUDFSerializer, self).__init__(
            timezone=timezone,
            safecheck=safecheck,
            # Aggregate results are positional; the caller-provided flag is
            # only retained on the instance, not forwarded to the base class.
            assign_cols_by_name=False,
            df_for_struct=False,
            struct_in_pandas="dict",
            ndarray_as_list=False,
            arrow_cast=True,
            input_types=None,
            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
        )
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name

    def load_stream(self, stream):
        """
        Deserialize grouped ArrowRecordBatches and yield as a list of pandas.Series.
        """
        import pyarrow as pa

        dataframes_in_group = None

        # Each group is preceded by a count: 1 means one dataframe follows,
        # 0 terminates the stream; anything else is a protocol error.
        while dataframes_in_group is None or dataframes_in_group > 0:
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 1:
                yield (
                    [
                        self.arrow_to_pandas(c, i)
                        for i, c in enumerate(
                            pa.Table.from_batches(
                                ArrowStreamSerializer.load_stream(self, stream)
                            ).itercolumns()
                        )
                    ]
                )

            elif dataframes_in_group != 0:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
                )

    def __repr__(self):
        return "ArrowStreamAggPandasUDFSerializer"
+ """ + import pyarrow as pa + + dataframes_in_group = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 1: + yield ( + [ + self.arrow_to_pandas(c, i) + for i, c in enumerate( + pa.Table.from_batches( + ArrowStreamSerializer.load_stream(self, stream) + ).itercolumns() + ) + ] + ) + + elif dataframes_in_group != 0: + raise PySparkValueError( + errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP", + messageParameters={"dataframes_in_group": str(dataframes_in_group)}, + ) + + def __repr__(self): + return "ArrowStreamAggPandasUDFSerializer" + + +# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): def __init__( self, diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py index 6e1cbdaf73cff..7cfcb29f50c17 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py @@ -190,6 +190,32 @@ def test_multiple_udfs(self): assert_frame_equal(expected1.toPandas(), result1.toPandas()) + def test_multiple_udfs_in_single_projection(self): + """ + Test multiple window aggregate pandas UDFs in a single select/projection. 
+ """ + df = self.data + w = self.unbounded_window + + # Use select() with multiple window UDFs in the same projection + result1 = df.select( + df["id"], + df["v"], + self.pandas_agg_mean_udf(df["v"]).over(w).alias("mean_v"), + self.pandas_agg_max_udf(df["v"]).over(w).alias("max_v"), + self.pandas_agg_min_udf(df["w"]).over(w).alias("min_w"), + ) + + expected1 = df.select( + df["id"], + df["v"], + sf.mean(df["v"]).over(w).alias("mean_v"), + sf.max(df["v"]).over(w).alias("max_v"), + sf.min(df["w"]).over(w).alias("min_w"), + ) + + assert_frame_equal(expected1.toPandas(), result1.toPandas()) + def test_replace_existing(self): df = self.data w = self.unbounded_window diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 94e3b2728d08f..96aac6083bf27 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -67,6 +67,7 @@ TransformWithStateInPySparkRowSerializer, TransformWithStateInPySparkRowInitStateSerializer, ArrowStreamArrowUDFSerializer, + ArrowStreamAggPandasUDFSerializer, ArrowStreamAggArrowUDFSerializer, ArrowBatchUDFSerializer, ArrowStreamUDTFSerializer, @@ -2721,10 +2722,13 @@ def read_udfs(pickleSer, infile, eval_type): ): ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True) elif eval_type in ( - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, ): + ser = ArrowStreamAggPandasUDFSerializer( + timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled + ) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: ser = GroupPandasUDFSerializer( timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled )