35 commits
d1e4f67
refactor: merge GroupPandasIterUDFSerializer to GroupPandasUDFSerializer
Yicong-Huang Nov 13, 2025
6895207
fix: format
Yicong-Huang Nov 13, 2025
c6ad502
fix: format
Yicong-Huang Nov 13, 2025
35ce001
feat: redesign the wrapper and serializer
Yicong-Huang Nov 14, 2025
37bcda4
fix: format
Yicong-Huang Nov 14, 2025
f216df3
fix: handle comments
Yicong-Huang Nov 14, 2025
6410866
chore: clean up
Yicong-Huang Nov 14, 2025
0cc9432
fix: change serializer input to list
Yicong-Huang Nov 20, 2025
b3dfd06
fix: align with GroupArrowUDFSerializer
Yicong-Huang Nov 20, 2025
c7317cf
fix: format
Yicong-Huang Nov 20, 2025
5eab53c
fix: wrong indentation
Yicong-Huang Nov 20, 2025
3a45f53
fix: move order
Yicong-Huang Nov 20, 2025
91fb21b
fix: use two serializers
Yicong-Huang Nov 24, 2025
89faad8
fix: format
Yicong-Huang Nov 24, 2025
ccaa9b4
fix: format
Yicong-Huang Nov 24, 2025
cf96d3b
fix: format
Yicong-Huang Nov 25, 2025
178fbce
fix: remove changes with multi UDF, SQL_GROUPED_MAP_PANDAS_UDF and S…
Yicong-Huang Nov 25, 2025
57fe551
fix: remove ArrowStreamAggPandasUDFSerializer
Yicong-Huang Nov 25, 2025
ff64faa
fix: remove mapper block
Yicong-Huang Nov 25, 2025
ea9dac1
fix: another implementation to seperate list and iterator expectation
Yicong-Huang Nov 26, 2025
1bce586
Merge branch 'apache:master' into SPARK-54316/refactor/consolidate-pa…
Yicong-Huang Nov 26, 2025
2614bc3
Merge branch 'master' into SPARK-54316/refactor/consolidate-pandas-it…
Yicong-Huang Nov 30, 2025
8e4173d
fix: simplify GroupPandasUDFSerializer
Yicong-Huang Nov 30, 2025
ce9fe22
fix: simplify serialzier use baseclass
Yicong-Huang Nov 30, 2025
edace86
fix: simplify wrappers
Yicong-Huang Nov 30, 2025
484962a
fix: comments
Yicong-Huang Nov 30, 2025
11cdc2b
fix: revert unrelated changes
Yicong-Huang Nov 30, 2025
22890b1
refactor: consolidate GroupPandasUDFSerializer to support both SQL_GR…
Yicong-Huang Dec 1, 2025
e89139f
fix: simplify
Yicong-Huang Dec 1, 2025
81b3ab8
fix: batch concatination and key type
Yicong-Huang Dec 1, 2025
86e71cf
fix: format
Yicong-Huang Dec 1, 2025
172ca80
fix: simplify and format
Yicong-Huang Dec 1, 2025
86e73b8
refactor: consolidate GroupPandasUDFSerializer to support both SQL_GR…
Yicong-Huang Dec 1, 2025
435e916
Merge branch 'master' into SPARK-54316/refactor/consolidate-pandas-it…
Yicong-Huang Dec 3, 2025
b33c44e
fix: missing import
Yicong-Huang Dec 3, 2025
83 changes: 11 additions & 72 deletions python/pyspark/sql/pandas/serializers.py
@@ -1317,72 +1317,16 @@ def __init__(

def load_stream(self, stream):
"""
Deserialize Grouped ArrowRecordBatches to a tuple of Arrow tables and yield as a
list of pandas.Series.
"""
import pyarrow as pa

dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == 1:
batches = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches(batches).itercolumns())
]
)

elif dataframes_in_group != 0:
raise PySparkValueError(
errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
messageParameters={"dataframes_in_group": str(dataframes_in_group)},
)

def __repr__(self):
return "GroupPandasUDFSerializer"


class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
"""
Serializer for grouped map Pandas iterator UDFs.

Loads grouped data as pandas.Series and serializes results from iterator UDFs.
Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
"""

def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled,
):
super(GroupPandasIterUDFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
assign_cols_by_name=assign_cols_by_name,
df_for_struct=False,
struct_in_pandas="dict",
ndarray_as_list=False,
arrow_cast=True,
input_types=None,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)

def load_stream(self, stream):
"""
Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
(one list per batch), allowing the iterator UDF to process data batch-by-batch.
Deserialize grouped ArrowRecordBatches and yield as Iterator[Iterator[List[pd.Series]]].
Each outer iterator element represents a group, containing an iterator of Series lists
(one list per batch).
"""
import pyarrow as pa

def process_group(batches: "Iterator[pa.RecordBatch]"):
# Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch
for batch in batches:
# The batch from ArrowStreamSerializer is already flattened (no struct wrapper)
series = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1411,21 +1355,16 @@ def process_group(batches: "Iterator[pa.RecordBatch]"):

def dump_stream(self, iterator, stream):
"""
Flatten the (dataframes_generator, arrow_type) tuples by iterating over each generator.
This allows the iterator UDF to stream results without materializing all DataFrames.
Flatten the Iterator[Iterator[[(df, arrow_type)]]] produced by the mapper.
Each mapper call yields Iterator[[(df, arrow_type)]] for one group, so we flatten
one level to match the parent's expected format Iterator[[(df, arrow_type)]].
"""
# Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), (df, arrow_type), ...
flattened_iter = (
(df, arrow_type) for dataframes_gen, arrow_type in iterator for df in dataframes_gen
)

# Convert each (df, arrow_type) to the format expected by parent's dump_stream
series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)

super(GroupPandasIterUDFSerializer, self).dump_stream(series_iter, stream)
# Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]]
flattened_iter = (batch for generator in iterator for batch in generator)
super(GroupPandasUDFSerializer, self).dump_stream(flattened_iter, stream)

def __repr__(self):
return "GroupPandasIterUDFSerializer"
return "GroupPandasUDFSerializer"


class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
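
For reviewers, a minimal standalone sketch of the data shapes the consolidated serializer now passes around: load_stream yields one inner iterator per group, each inner element being a list of pandas.Series (one list per Arrow batch), and dump_stream flattens the per-group Iterator[Iterator[[(df, arrow_type)]]] by one level before delegating to the parent. This is illustrative only; fake_load_stream and flatten below are not part of the PR, and a plain string stands in for the real pyarrow type.

# Illustrative sketch only -- not part of the PR. It mirrors the shapes the
# consolidated GroupPandasUDFSerializer works with, using plain pandas objects
# instead of a real Arrow stream.
from typing import Iterator, List, Tuple
import pandas as pd

ArrowType = str  # stand-in for a pyarrow DataType in this sketch

def fake_load_stream() -> Iterator[Iterator[List[pd.Series]]]:
    # Two groups, each arriving as two batches; every batch is a list of
    # column Series, one Series per column.
    def one_group(base: int) -> Iterator[List[pd.Series]]:
        for batch in range(2):
            yield [pd.Series([base + batch]), pd.Series([base + batch + 10])]
    yield one_group(0)
    yield one_group(100)

def flatten(
    nested: Iterator[Iterator[List[Tuple[pd.DataFrame, ArrowType]]]],
) -> Iterator[List[Tuple[pd.DataFrame, ArrowType]]]:
    # The same one-level flattening dump_stream performs before calling the
    # parent serializer: Iterator[Iterator[[(df, type)]]] -> Iterator[[(df, type)]].
    return (batch for per_group in nested for batch in per_group)

for per_group in fake_load_stream():
    for series_list in per_group:
        print([s.tolist() for s in series_list])
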
100 changes: 52 additions & 48 deletions python/pyspark/worker.py
@@ -54,12 +54,11 @@
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
GroupArrowUDFSerializer,
GroupPandasUDFSerializer,
CogroupArrowUDFSerializer,
CogroupPandasUDFSerializer,
ArrowStreamUDFSerializer,
ApplyInPandasWithStateSerializer,
GroupPandasIterUDFSerializer,
GroupPandasUDFSerializer,
TransformWithStateInPandasSerializer,
TransformWithStateInPandasInitStateSerializer,
TransformWithStateInPySparkRowSerializer,
@@ -736,43 +735,59 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
_use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)

def wrapped(key_series, value_series):
def wrapped(key_series, value_batches):
import pandas as pd

# Convert value_batches (Iterator[list[pd.Series]]) to a single DataFrame
# Each value_series is a list of Series (one per column) for one batch
# Concatenate Series within each batch (axis=1), then concatenate batches (axis=0)
value_dataframes = []
for value_series in value_batches:
value_dataframes.append(pd.concat(value_series, axis=1))

value_df = pd.concat(value_dataframes, axis=0) if value_dataframes else pd.DataFrame()

if len(argspec.args) == 1:
result = f(pd.concat(value_series, axis=1))
result = f(value_df)
elif len(argspec.args) == 2:
key = tuple(s[0] for s in key_series)
result = f(key, pd.concat(value_series, axis=1))
# Extract key from pandas Series, preserving numpy types
key = tuple(s.iloc[0] for s in key_series)
result = f(key, value_df)

verify_pandas_result(
result, return_type, _assign_cols_by_name, truncate_return_schema=False
)

return result
yield result

arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
return lambda k, v: [(wrapped(k, v), arrow_return_type)]

def flatten_wrapper(k, v):
# Return Iterator[[(df, arrow_type)]] directly
for df in wrapped(k, v):
yield [(df, arrow_return_type)]

return flatten_wrapper


def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
_use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)

def wrapped(key_series_list, value_series_gen):
def wrapped(key_series, value_batches):
import pandas as pd

# value_series_gen is a generator that yields multiple lists of Series (one per batch)
# value_batches is an Iterator[list[pd.Series]] (one list per batch)
# Convert each list of Series into a DataFrame
def dataframe_iter():
for value_series in value_series_gen:
for value_series in value_batches:
yield pd.concat(value_series, axis=1)

# Extract key from the first batch
if len(argspec.args) == 1:
result = f(dataframe_iter())
elif len(argspec.args) == 2:
# key_series_list is a list of Series for the key columns from the first batch
key = tuple(s[0] for s in key_series_list)
# Extract key from pandas Series, preserving numpy types
key = tuple(s.iloc[0] for s in key_series)
result = f(key, dataframe_iter())

def verify_element(df):
@@ -784,7 +799,13 @@ def verify_element(df):
yield from map(verify_element, result)

arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
return lambda k, v: (wrapped(k, v), arrow_return_type)

def flatten_wrapper(k, v):
# Return Iterator[[(df, arrow_type)]] directly
for df in wrapped(k, v):
yield [(df, arrow_return_type)]

return flatten_wrapper


def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
@@ -2755,14 +2776,13 @@ def read_udfs(pickleSer, infile, eval_type):
ser = ArrowStreamAggPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
elif (
eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
):
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
ser = GroupPandasIterUDFSerializer(
timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
@@ -3003,26 +3023,12 @@ def extract_key_value_indexes(grouped_arg_offsets):
idx += offsets_len
return parsed

if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1

# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
arg_offsets, f = read_single_udf(
pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
)
parsed_offsets = extract_key_value_indexes(arg_offsets)

# Create function like this:
# mapper a: f([a[0]], [a[0], a[1]])
def mapper(a):
keys = [a[o] for o in parsed_offsets[0][0]]
vals = [a[o] for o in parsed_offsets[0][1]]
return f(keys, vals)
if (
eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
):
import pyarrow as pa

elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
@@ -3034,23 +3040,21 @@ def mapper(a):
)
parsed_offsets = extract_key_value_indexes(arg_offsets)

# Create mapper similar to Arrow iterator:
# `a` is an iterator of Series lists (one list per batch, containing all columns)
# Materialize first batch to get keys, then create generator for value batches
def mapper(a):
import itertools

series_iter = iter(a)
def mapper(series_iter):
# Need to materialize the first series list to get the keys
first_series_list = next(series_iter)

keys = [first_series_list[o] for o in parsed_offsets[0][0]]
# Extract key Series from the first batch
key_series = [first_series_list[o] for o in parsed_offsets[0][0]]

# Create generator for value Series lists (one list per batch)
value_series_gen = (
[series_list[o] for o in parsed_offsets[0][1]]
for series_list in itertools.chain((first_series_list,), series_iter)
)

return f(keys, value_series_gen)
# Flatten one level: yield from wrapper to return Iterator[[(df, arrow_type)]]
yield from f(key_series, value_series_gen)

elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
# We assume there is only one UDF here because grouped map doesn't
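
To summarize the worker-side flow after this change, here is a hedged standalone sketch (plain Python, no Spark imports) of how the shared mapper splits each incoming batch of Series into key and value columns by offset, reads the key from the first batch, and re-emits per-group results as Iterator[[(df, arrow_type)]]. The names key_offsets, value_offsets, udf, and ARROW_TYPE are illustrative assumptions, not the actual worker identifiers.

# Illustrative sketch only -- not the actual worker code. It shows the
# key/value offset split and the one-level flattening the new mapper relies on.
import itertools
from typing import Iterator, List
import pandas as pd

key_offsets = [0]        # column positions of grouping columns (illustrative)
value_offsets = [0, 1]   # column positions of data columns (illustrative)
ARROW_TYPE = "int64"     # stand-in for the real pyarrow return type

def udf(key, pdf_iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # A user function in the iterator style: one DataFrame in per batch, one out.
    for pdf in pdf_iter:
        yield pdf.assign(key=key[0])

def mapper(series_iter: Iterator[List[pd.Series]]):
    # Materialize the first batch to read the key, then chain it back in
    # front of the remaining batches for the value generator.
    first = next(series_iter)
    key_series = [first[o] for o in key_offsets]
    key = tuple(s.iloc[0] for s in key_series)
    value_batches = (
        [batch[o] for o in value_offsets]
        for batch in itertools.chain((first,), series_iter)
    )
    pdf_iter = (pd.concat(series, axis=1) for series in value_batches)
    # Re-emit as Iterator[[(df, arrow_type)]], one element per output batch.
    for out in udf(key, pdf_iter):
        yield [(out, ARROW_TYPE)]

batches = iter([
    [pd.Series([1, 1]), pd.Series([10, 20])],
    [pd.Series([1]), pd.Series([30])],
])
for payload in mapper(batches):
    print(payload[0][0])
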