diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index e6a7bf40b945..4d565ecfd939 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -138,6 +138,23 @@ def iter_to_iter(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
 
         self.spark.range(10).select(iter_to_iter("id")).collect()
 
+    def exec_arrow_udf_grouped_agg_iter(self):
+        import pyarrow as pa
+
+        @arrow_udf("double")
+        def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+            sum_val = 0.0
+            cnt = 0
+            for v in it:
+                sum_val += pa.compute.sum(v).as_py()
+                cnt += len(v)
+            return sum_val / cnt if cnt > 0 else 0.0
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+        df.groupby("id").agg(arrow_mean_iter(df["v"])).collect()
+
     # Unsupported
     def exec_map(self):
         import pandas as pd
@@ -169,6 +186,15 @@ def test_unsupported(self):
             "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
         )
 
+        with warnings.catch_warnings(record=True) as warns:
+            warnings.simplefilter("always")
+            self.exec_arrow_udf_grouped_agg_iter()
+            user_warns = [warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+            self.assertTrue(len(user_warns) > 0)
+            self.assertTrue(
+                "Profiling UDFs with iterators input/output is not supported" in str(user_warns[0])
+            )
+
         with warnings.catch_warnings(record=True) as warns:
             warnings.simplefilter("always")
             self.exec_map()
@@ -486,6 +512,31 @@ def min_udf(v: pa.Array) -> float:
         for id in self.profile_results:
             self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
 
+    @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+    def test_perf_profiler_arrow_udf_grouped_agg_iter(self):
+        import pyarrow as pa
+        from typing import Iterator
+
+        @arrow_udf("double")
+        def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+            sum_val = 0.0
+            cnt = 0
+            for v in it:
+                sum_val += pa.compute.sum(v).as_py()
+                cnt += len(v)
+            return sum_val / cnt if cnt > 0 else 0.0
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df = self.spark.createDataFrame(
+                [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+            )
+            df.groupBy(df.id).agg(arrow_mean_iter(df["v"])).show()
+
+        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c7471d19f7d6..75bcb66efdb8 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -438,6 +438,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
                 PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
                 PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+                PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
             ]:
                 warnings.warn(
                     "Profiling UDFs with iterators input/output is not supported.",