diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py index 74a81be37f80..844c7f111db4 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py @@ -1153,13 +1153,14 @@ def test_iterator_grouped_agg_partial_consumption(self): # Create a dataset with multiple batches per group # Use small batch size to ensure multiple batches per group + # Use same value for all data points to avoid ordering issues with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}): df = self.spark.createDataFrame( - [(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)], ("id", "v") + [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v") ) - @arrow_udf("double") - def arrow_sum_partial(it: Iterator[pa.Array]) -> float: + @arrow_udf("struct") + def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict: # Only consume first two batches, then return # This tests that partial consumption works correctly total = 0.0 @@ -1171,32 +1172,44 @@ def arrow_sum_partial(it: Iterator[pa.Array]) -> float: else: # Stop early - partial consumption break - return total / count if count > 0 else 0.0 + return {"count": count, "sum": total} - result = df.groupby("id").agg(arrow_sum_partial(df["v"]).alias("mean")).sort("id") + result = ( + df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id") + ) # Verify results are correct for partial consumption # With batch size = 2: # Group 1 (id=1): 4 values in 2 batches -> processes both batches - # Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0] - # Result: (1.0+2.0+3.0+4.0)/4 = 2.5 + # Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0] + # Result: count=4, sum=4.0 # Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available) - # Batch 1: [5.0, 6.0] - # Result: (5.0+6.0)/2 = 5.5 + # Batch 1: [1.0, 1.0] + # Result: count=2, sum=2.0 actual = result.collect() self.assertEqual(len(actual), 2, "Should have results for both groups") # Verify both groups were processed correctly # Group 1: processes 2 batches (all available) group1_result = next(row for row in actual if row["id"] == 1) + self.assertEqual( + group1_result["result"]["count"], + 4, + msg="Group 1 should process 4 values (2 batches)", + ) self.assertAlmostEqual( - group1_result["mean"], 2.5, places=5, msg="Group 1 should process 2 batches" + group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0" ) # Group 2: processes 1 batch (only batch available) group2_result = next(row for row in actual if row["id"] == 2) + self.assertEqual( + group2_result["result"]["count"], + 2, + msg="Group 2 should process 2 values (1 batch)", + ) self.assertAlmostEqual( - group2_result["mean"], 5.5, places=5, msg="Group 2 should process 1 batch" + group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0" )