diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 9cef600dc646d..667af40c36bc9 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1317,72 +1317,16 @@ def __init__(
 
     def load_stream(self, stream):
         """
-        Deserialize Grouped ArrowRecordBatches to a tuple of Arrow tables and yield as a
-        list of pandas.Series.
-        """
-        import pyarrow as pa
-
-        dataframes_in_group = None
-
-        while dataframes_in_group is None or dataframes_in_group > 0:
-            dataframes_in_group = read_int(stream)
-
-            if dataframes_in_group == 1:
-                batches = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
-                yield (
-                    [
-                        self.arrow_to_pandas(c, i)
-                        for i, c in enumerate(pa.Table.from_batches(batches).itercolumns())
-                    ]
-                )
-
-            elif dataframes_in_group != 0:
-                raise PySparkValueError(
-                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
-                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
-                )
-
-    def __repr__(self):
-        return "GroupPandasUDFSerializer"
-
-
-class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
-    """
-    Serializer for grouped map Pandas iterator UDFs.
-
-    Loads grouped data as pandas.Series and serializes results from iterator UDFs.
-    Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
-    """
-
-    def __init__(
-        self,
-        timezone,
-        safecheck,
-        assign_cols_by_name,
-        int_to_decimal_coercion_enabled,
-    ):
-        super(GroupPandasIterUDFSerializer, self).__init__(
-            timezone=timezone,
-            safecheck=safecheck,
-            assign_cols_by_name=assign_cols_by_name,
-            df_for_struct=False,
-            struct_in_pandas="dict",
-            ndarray_as_list=False,
-            arrow_cast=True,
-            input_types=None,
-            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
-        )
-
-    def load_stream(self, stream):
-        """
-        Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
-        (one list per batch), allowing the iterator UDF to process data batch-by-batch.
+        Deserialize Grouped ArrowRecordBatches and yield as Iterator[Iterator[list[pd.Series]]].
+        Each outer iterator element represents a group, containing an iterator of Series lists
+        (one list per batch).
         """
         import pyarrow as pa
 
         def process_group(batches: "Iterator[pa.RecordBatch]"):
             # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch
             for batch in batches:
+                # The batch from ArrowStreamSerializer is already flattened (no struct wrapper)
                 series = [
                     self.arrow_to_pandas(c, i)
                     for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1411,21 +1355,16 @@ def process_group(batches: "Iterator[pa.RecordBatch]"):
 
     def dump_stream(self, iterator, stream):
         """
-        Flatten the (dataframes_generator, arrow_type) tuples by iterating over each generator.
-        This allows the iterator UDF to stream results without materializing all DataFrames.
+        Flatten the Iterator[Iterator[[(df, arrow_type)]]] returned by func.
+        The mapper returns Iterator[[(df, arrow_type)]] per group, so we flatten one level
+        to match the parent's expected format Iterator[[(df, arrow_type)]].
         """
-        # Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), (df, arrow_type), ...
-        flattened_iter = (
-            (df, arrow_type) for dataframes_gen, arrow_type in iterator for df in dataframes_gen
-        )
-
-        # Convert each (df, arrow_type) to the format expected by parent's dump_stream
-        series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)
-
-        super(GroupPandasIterUDFSerializer, self).dump_stream(series_iter, stream)
+        # Flatten: Iterator[Iterator[[(df, arrow_type)]]] -> Iterator[[(df, arrow_type)]]
+        flattened_iter = (batch for generator in iterator for batch in generator)
+        super(GroupPandasUDFSerializer, self).dump_stream(flattened_iter, stream)
 
     def __repr__(self):
-        return "GroupPandasIterUDFSerializer"
+        return "GroupPandasUDFSerializer"
 
 
 class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6a5bde5d2acc2..0c3fadbc6a9fb 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -54,12 +54,11 @@
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupArrowUDFSerializer,
+    GroupPandasUDFSerializer,
     CogroupArrowUDFSerializer,
     CogroupPandasUDFSerializer,
     ArrowStreamUDFSerializer,
     ApplyInPandasWithStateSerializer,
-    GroupPandasIterUDFSerializer,
-    GroupPandasUDFSerializer,
     TransformWithStateInPandasSerializer,
     TransformWithStateInPandasInitStateSerializer,
     TransformWithStateInPySparkRowSerializer,
@@ -736,43 +735,59 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     _use_large_var_types = use_large_var_types(runner_conf)
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
-    def wrapped(key_series, value_series):
+    def wrapped(key_series, value_batches):
         import pandas as pd
 
+        # Convert value_batches (Iterator[list[pd.Series]]) to a single DataFrame
+        # Each value_series is a list of Series (one per column) for one batch
+        # Concatenate Series within each batch (axis=1), then concatenate batches (axis=0)
+        value_dataframes = []
+        for value_series in value_batches:
+            value_dataframes.append(pd.concat(value_series, axis=1))
+
+        value_df = pd.concat(value_dataframes, axis=0) if value_dataframes else pd.DataFrame()
+
         if len(argspec.args) == 1:
-            result = f(pd.concat(value_series, axis=1))
+            result = f(value_df)
         elif len(argspec.args) == 2:
-            key = tuple(s[0] for s in key_series)
-            result = f(key, pd.concat(value_series, axis=1))
+            # Extract key from pandas Series, preserving numpy types
+            key = tuple(s.iloc[0] for s in key_series)
+            result = f(key, value_df)
+
         verify_pandas_result(
             result, return_type, _assign_cols_by_name, truncate_return_schema=False
         )
-        return result
+        yield result
 
     arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
-    return lambda k, v: [(wrapped(k, v), arrow_return_type)]
+
+    def flatten_wrapper(k, v):
+        # Return Iterator[[(df, arrow_type)]] directly
+        for df in wrapped(k, v):
+            yield [(df, arrow_return_type)]
+
+    return flatten_wrapper
 
 
 def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
     _use_large_var_types = use_large_var_types(runner_conf)
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
-    def wrapped(key_series_list, value_series_gen):
+    def wrapped(key_series, value_batches):
         import pandas as pd
 
-        # value_series_gen is a generator that yields multiple lists of Series (one per batch)
+        # value_batches is an Iterator[list[pd.Series]] (one list per batch)
         # Convert each list of Series into a DataFrame
         def dataframe_iter():
-            for value_series in value_series_gen:
+            for value_series in value_batches:
                 yield pd.concat(value_series, axis=1)
 
-        # Extract key from the first batch
         if len(argspec.args) == 1:
             result = f(dataframe_iter())
         elif len(argspec.args) == 2:
-            # key_series_list is a list of Series for the key columns from the first batch
-            key = tuple(s[0] for s in key_series_list)
+            # Extract key from pandas Series, preserving numpy types
+            key = tuple(s.iloc[0] for s in key_series)
             result = f(key, dataframe_iter())
 
         def verify_element(df):
@@ -784,7 +799,13 @@ def verify_element(df):
         yield from map(verify_element, result)
 
     arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
-    return lambda k, v: (wrapped(k, v), arrow_return_type)
+
+    def flatten_wrapper(k, v):
+        # Return Iterator[[(df, arrow_type)]] directly
+        for df in wrapped(k, v):
+            yield [(df, arrow_return_type)]
+
+    return flatten_wrapper
 
 
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
@@ -2755,14 +2776,13 @@ def read_udfs(pickleSer, infile, eval_type):
         ser = ArrowStreamAggPandasUDFSerializer(
             timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
         )
-    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+    elif (
+        eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+        or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+    ):
         ser = GroupPandasUDFSerializer(
             timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
         )
-    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
-        ser = GroupPandasIterUDFSerializer(
-            timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
-        )
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
         ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
@@ -3003,26 +3023,12 @@ def extract_key_value_indexes(grouped_arg_offsets):
                 idx += offsets_len
             return parsed
 
-        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-            # We assume there is only one UDF here because grouped map doesn't
-            # support combining multiple UDFs.
-            assert num_udfs == 1
-
-            # See FlatMapGroupsInPandasExec for how arg_offsets are used to
-            # distinguish between grouping attributes and data attributes
-            arg_offsets, f = read_single_udf(
-                pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
-            )
-            parsed_offsets = extract_key_value_indexes(arg_offsets)
-
-            # Create function like this:
-            #     mapper a: f([a[0]], [a[0], a[1]])
-            def mapper(a):
-                keys = [a[o] for o in parsed_offsets[0][0]]
-                vals = [a[o] for o in parsed_offsets[0][1]]
-                return f(keys, vals)
+        if (
+            eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+            or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+        ):
+            import pyarrow as pa
 
-        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
             # We assume there is only one UDF here because grouped map doesn't
             # support combining multiple UDFs.
             assert num_udfs == 1
@@ -3034,23 +3040,21 @@ def mapper(a):
             )
             parsed_offsets = extract_key_value_indexes(arg_offsets)
 
-            # Create mapper similar to Arrow iterator:
-            # `a` is an iterator of Series lists (one list per batch, containing all columns)
-            # Materialize first batch to get keys, then create generator for value batches
-            def mapper(a):
-                import itertools
-
-                series_iter = iter(a)
+            def mapper(series_iter):
                 # Need to materialize the first series list to get the keys
                 first_series_list = next(series_iter)
-                keys = [first_series_list[o] for o in parsed_offsets[0][0]]
+                # Extract key Series from the first batch
+                key_series = [first_series_list[o] for o in parsed_offsets[0][0]]
+
+                # Create generator for value Series lists (one list per batch)
                 value_series_gen = (
                     [series_list[o] for o in parsed_offsets[0][1]]
                     for series_list in itertools.chain((first_series_list,), series_iter)
                 )
-                return f(keys, value_series_gen)
+                # Flatten one level: yield from wrapper to return Iterator[[(df, arrow_type)]]
+                yield from f(key_series, value_series_gen)
 
         elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
             # We assume there is only one UDF here because grouped map doesn't
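
For reference, the per-group reassembly performed by the unified wrap_grouped_map_pandas_udf (an iterator of per-batch lists of pd.Series rebuilt into a single DataFrame per group) can be sketched standalone with toy data; the names and values below are illustrative and not taken from the patch.

import pandas as pd

# Toy per-group input: an iterator of per-batch lists of pd.Series (one Series per column).
value_batches = iter(
    [
        [pd.Series([1, 2], name="id"), pd.Series([10.0, 20.0], name="v")],  # batch 1
        [pd.Series([3], name="id"), pd.Series([30.0], name="v")],           # batch 2
    ]
)

# Concatenate Series within each batch (axis=1), then concatenate batches (axis=0),
# mirroring the wrapped() helper above.
value_dataframes = [pd.concat(value_series, axis=1) for value_series in value_batches]
value_df = pd.concat(value_dataframes, axis=0) if value_dataframes else pd.DataFrame()

print(value_df)
#    id     v
# 0   1  10.0
# 1   2  20.0
# 0   3  30.0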