Skip to content

Commit 58ba7db

Browse files
Yicong-Huang and xu20160924
authored and committed
[SPARK-53615][FOLLOWUP][PYTHON][TEST] Fix test case for arrow grouped agg UDF partial consumption
### What changes were proposed in this pull request? Fix test case `test_iterator_grouped_agg_partial_consumption` to use count and sum instead of mean for testing partial consumption. Use the same value for all data points to avoid ordering issues. ### Why are the changes needed? Fixes test flakiness by ensuring test data points have the same value and using count/sum metrics that properly validate partial consumption behavior, making the test robust against ordering variations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Ran existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#53372 from Yicong-Huang/SPARK-53615/fix/fix-test-partial-consumption-ordering. Authored-by: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 00c09db commit 58ba7db

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,13 +1153,14 @@ def test_iterator_grouped_agg_partial_consumption(self):
11531153

11541154
# Create a dataset with multiple batches per group
11551155
# Use small batch size to ensure multiple batches per group
1156+
# Use same value for all data points to avoid ordering issues
11561157
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
11571158
df = self.spark.createDataFrame(
1158-
[(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
1159+
[(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v")
11591160
)
11601161

1161-
@arrow_udf("double")
1162-
def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
1162+
@arrow_udf("struct<count:bigint,sum:double>")
1163+
def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict:
11631164
# Only consume first two batches, then return
11641165
# This tests that partial consumption works correctly
11651166
total = 0.0
@@ -1171,32 +1172,44 @@ def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
11711172
else:
11721173
# Stop early - partial consumption
11731174
break
1174-
return total / count if count > 0 else 0.0
1175+
return {"count": count, "sum": total}
11751176

1176-
result = df.groupby("id").agg(arrow_sum_partial(df["v"]).alias("mean")).sort("id")
1177+
result = (
1178+
df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id")
1179+
)
11771180

11781181
# Verify results are correct for partial consumption
11791182
# With batch size = 2:
11801183
# Group 1 (id=1): 4 values in 2 batches -> processes both batches
1181-
# Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0]
1182-
# Result: (1.0+2.0+3.0+4.0)/4 = 2.5
1184+
# Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0]
1185+
# Result: count=4, sum=4.0
11831186
# Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available)
1184-
# Batch 1: [5.0, 6.0]
1185-
# Result: (5.0+6.0)/2 = 5.5
1187+
# Batch 1: [1.0, 1.0]
1188+
# Result: count=2, sum=2.0
11861189
actual = result.collect()
11871190
self.assertEqual(len(actual), 2, "Should have results for both groups")
11881191

11891192
# Verify both groups were processed correctly
11901193
# Group 1: processes 2 batches (all available)
11911194
group1_result = next(row for row in actual if row["id"] == 1)
1195+
self.assertEqual(
1196+
group1_result["result"]["count"],
1197+
4,
1198+
msg="Group 1 should process 4 values (2 batches)",
1199+
)
11921200
self.assertAlmostEqual(
1193-
group1_result["mean"], 2.5, places=5, msg="Group 1 should process 2 batches"
1201+
group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0"
11941202
)
11951203

11961204
# Group 2: processes 1 batch (only batch available)
11971205
group2_result = next(row for row in actual if row["id"] == 2)
1206+
self.assertEqual(
1207+
group2_result["result"]["count"],
1208+
2,
1209+
msg="Group 2 should process 2 values (1 batch)",
1210+
)
11981211
self.assertAlmostEqual(
1199-
group2_result["mean"], 5.5, places=5, msg="Group 2 should process 1 batch"
1212+
group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0"
12001213
)
12011214

12021215

0 commit comments

Comments (0)