python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py (35 changes: 24 additions & 11 deletions)
@@ -1153,13 +1153,14 @@ def test_iterator_grouped_agg_partial_consumption(self):
 
         # Create a dataset with multiple batches per group
         # Use small batch size to ensure multiple batches per group
+        # Use same value for all data points to avoid ordering issues
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
             df = self.spark.createDataFrame(
-                [(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
+                [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v")
             )
 
-            @arrow_udf("double")
-            def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
+            @arrow_udf("struct<count:bigint,sum:double>")
+            def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict:
                 # Only consume first two batches, then return
                 # This tests that partial consumption works correctly
                 total = 0.0
@@ -1171,32 +1172,44 @@ def arrow_sum_partial(it: Iterator[pa.Array]) -> float:
                     else:
                         # Stop early - partial consumption
                         break
-                return total / count if count > 0 else 0.0
+                return {"count": count, "sum": total}
 
-            result = df.groupby("id").agg(arrow_sum_partial(df["v"]).alias("mean")).sort("id")
+            result = (
+                df.groupby("id").agg(arrow_count_sum_partial(df["v"]).alias("result")).sort("id")
+            )
 
             # Verify results are correct for partial consumption
             # With batch size = 2:
             # Group 1 (id=1): 4 values in 2 batches -> processes both batches
-            # Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0]
-            # Result: (1.0+2.0+3.0+4.0)/4 = 2.5
+            # Batch 1: [1.0, 1.0], Batch 2: [1.0, 1.0]
+            # Result: count=4, sum=4.0
             # Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (only 1 batch available)
-            # Batch 1: [5.0, 6.0]
-            # Result: (5.0+6.0)/2 = 5.5
+            # Batch 1: [1.0, 1.0]
+            # Result: count=2, sum=2.0
             actual = result.collect()
             self.assertEqual(len(actual), 2, "Should have results for both groups")
 
             # Verify both groups were processed correctly
             # Group 1: processes 2 batches (all available)
             group1_result = next(row for row in actual if row["id"] == 1)
+            self.assertEqual(
+                group1_result["result"]["count"],
+                4,
+                msg="Group 1 should process 4 values (2 batches)",
+            )
             self.assertAlmostEqual(
-                group1_result["mean"], 2.5, places=5, msg="Group 1 should process 2 batches"
+                group1_result["result"]["sum"], 4.0, places=5, msg="Group 1 should sum to 4.0"
             )
 
             # Group 2: processes 1 batch (only batch available)
             group2_result = next(row for row in actual if row["id"] == 2)
+            self.assertEqual(
+                group2_result["result"]["count"],
+                2,
+                msg="Group 2 should process 2 values (1 batch)",
+            )
             self.assertAlmostEqual(
-                group2_result["mean"], 5.5, places=5, msg="Group 2 should process 1 batch"
+                group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0"
             )
 
 
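For context, the pattern this test exercises (an iterator-based Arrow grouped-aggregate UDF that stops consuming batches early) can be sketched standalone. This is a minimal sketch mirroring the new test body, not the test itself: the import path for arrow_udf, the session setup, and the helper name count_sum_two_batches are assumptions, since the diff shows only the decorator's usage inside the test harness.

    from typing import Iterator

    import pyarrow as pa
    import pyarrow.compute as pc
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import arrow_udf  # assumed import path

    spark = SparkSession.builder.getOrCreate()
    # Two rows per Arrow batch, matching the test's config
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 2)

    df = spark.createDataFrame(
        [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)], ("id", "v")
    )

    @arrow_udf("struct<count:bigint,sum:double>")
    def count_sum_two_batches(it: Iterator[pa.Array]) -> dict:
        # Consume at most the first two batches, then stop early;
        # the runtime must tolerate the iterator not being drained.
        total, count = 0.0, 0
        for i, arr in enumerate(it):
            total += pc.sum(arr).as_py() or 0.0
            count += len(arr)
            if i >= 1:
                break  # partial consumption
        return {"count": count, "sum": total}

    df.groupby("id").agg(count_sum_two_batches(df["v"]).alias("result")).show()

Because every value is 1.0, the sum no longer depends on which rows land in which batch (the point of the new "avoid ordering issues" comment), while the count still verifies how many batches each group actually consumed.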