From 444f9a4b25107c79fa126ed4624090a7450f1231 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 13:00:30 +0900 Subject: [PATCH 01/38] [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 2 + .../pyspark/sql/pandas/_typing/__init__.pyi | 5 + python/pyspark/sql/pandas/functions.py | 2 + python/pyspark/sql/pandas/group_ops.py | 105 +++++++- python/pyspark/sql/pandas/serializers.py | 252 +++++++++++++++++- python/pyspark/sql/udf.py | 9 +- python/pyspark/worker.py | 116 ++++++++ .../UnsupportedOperationChecker.scala | 62 +++++ .../logical/pythonLogicalOperators.scala | 34 +++ .../apache/spark/sql/internal/SQLConf.scala | 47 ++++ .../spark/sql/RelationalGroupedDataset.scala | 31 +++ .../spark/sql/execution/SparkStrategies.scala | 22 ++ .../sql/execution/arrow/ArrowWriter.scala | 2 +- .../ApplyInPandasWithStatePythonRunner.scala | 197 ++++++++++++++ .../python/ApplyInPandasWithStateWriter.scala | 218 +++++++++++++++ .../python/FlatMapCoGroupsInPandasExec.scala | 4 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../FlatMapGroupsInPandasWithStateExec.scala | 217 +++++++++++++++ .../execution/python/PandasGroupUtils.scala | 5 +- .../streaming/IncrementalExecution.scala | 9 + .../execution/streaming/state/package.scala | 2 +- 22 files changed, 1332 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5a13674e8bfb..7b31fa93c32e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -53,6 +53,7 @@ private[spark] object PythonEvalType { val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7ef0014ae751..5f4f4d494e13 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi 
b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238b..9f855c6c1151 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal from types import FunctionType from pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupStateImpl from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray @@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -256,6 +258,9 @@ PandasGroupedMapFunction = Union[ Callable[[Any, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[[Any, Iterable[DataFrameLike], GroupStateImpl], + Iterable[DataFrameLike]] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb2959..d0f81e2f6335 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. @@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType): PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ]: # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered # at `apply` instead. diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e9..f845f2466011 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. # import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,105 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + """ + Applies the given function to each group of data, while maintaining a user-defined + per-group state. 
+        The result Dataset will represent the flattened records returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group in every
+        trigger, and updates to each group's state will be saved across invocations. The
+        function will also be invoked for each timed-out state. The sequence of invocation
+        is input data first, then state timeout; when the function is invoked for a state
+        timeout, no data is presented.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a
+        tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be
+        passed as :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the
+        user-function, and the returned `pandas.DataFrame` across all invocations are combined
+        into a :class:`DataFrame`. Note that the user function should iterate through and
+        process all elements in the iterator, and should not assume the number of elements in
+        the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements
+        in the returned `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not
+        strings, e.g. integer indices.
+
+        The `stateStructType` should be a :class:`StructType` describing the schema of the
+        user-defined state. The value of the state is presented as a tuple, and updates must
+        also be performed with a tuple. User-defined types, e.g. native Python class types,
+        are not supported. Alternatively, you can pickle the data and store it as BinaryType,
+        but that is tied to the backward and forward compatibility of pickle in Python, which
+        Spark itself does not guarantee.
+
+        The length of each element in both the input and the returned value, `pandas.DataFrame`,
+        can be arbitrary. The length of the iterator in both the input and the returned value
+        can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. Valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
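As a rough, illustrative sketch of the intended call pattern (not part of the patch): the function below keeps a running count per key in the user-defined state. `df` is assumed to be a streaming DataFrame with an `id` column, the schema strings and the name `count_fn` are made up for the example, and `exists`/`get` are assumed to be exposed as properties on the state object as in the final PySpark API.

    import pandas as pd
    from pyspark.sql.streaming.state import GroupStateTimeout

    def count_fn(key, pdf_iter, state):
        # The state holds a single-field tuple matching stateStructType ("count long").
        total = state.get[0] if state.exists else 0
        for pdf in pdf_iter:
            total += len(pdf)
        state.update((total,))
        # key is a tuple of numpy values, one per grouping column.
        yield pd.DataFrame({"id": [key[0]], "count": [total]})

    result = df.groupBy("id").applyInPandasWithState(
        count_fn,
        outputStructType="id long, count long",
        stateStructType="count long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )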
+ """ + + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 992e82b403a1..e561e8723819 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,7 +19,11 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer +import time + +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType class SpecialLengths: @@ -371,3 +375,249 @@ def load_stream(self, stream): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, + soft_limit_bytes_per_batch, min_data_count_for_sample, + soft_timeout_millis_purge_batch): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema + + self.result_state_df_type = StructType([ + StructField('properties', StringType()), + StructField('keyRowAsUnsafe', BinaryType()), + StructField('object', BinaryType()), + StructField('oldTimeoutTimestamp', LongType()), + ]) + + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.soft_limit_bytes_per_batch = soft_limit_bytes_per_batch + self.min_data_count_for_sample = min_data_count_for_sample + self.soft_timeout_millis_purge_batch = soft_timeout_millis_purge_batch + + def load_stream(self, stream): + import pyarrow as pa + import json + from itertools import groupby + from pyspark.sql.streaming.state import GroupStateImpl + + def gen_data_and_state(batches): + state_for_current_group = None + + for batch in batches: + batch_schema = batch.schema + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema([batch_schema[-1], ]) + + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = 
batch_columns[-1] + + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) + + state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[state_idx] + + if not state_info_col: + # no more data with grouping key + state + break + + state_info_col_properties = state_info_col['properties'] + state_info_col_key_row = state_info_col['keyRowAsUnsafe'] + state_info_col_object = state_info_col['object'] + + data_start_offset = state_info_col['startOffset'] + num_data_rows = state_info_col['numRows'] + is_last_chunk = state_info_col['isLastChunk'] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + if state_for_current_group: + # use the state, we already have state for same group and there should be + # some data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties) + + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not state_for_current_group: + # there's no cached state but expected to have additional data in same group + # cache the current state + state_for_current_group = state + + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + + # state info + yield (data_pandas, state, ) + + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + + data_state_generator = gen_data_and_state(batches) + + # state will be same object for same grouping key + for state, data in groupby(data_state_generator, key=lambda x: x[1]): + yield (data, state, ) + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Arrow RecordBatch requires all columns to have all same number of rows. + Insert empty data for state/data with less elements to compensate. 
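To make the padding rule concrete, a small standalone sketch (the column names and values are illustrative, not the serializer's actual schema): the shorter of the data/state frames is extended with all-null rows so that a single Arrow RecordBatch, which requires all columns to have the same length, can carry both.

    import pandas as pd
    import pyarrow as pa

    data_pdf = pd.DataFrame({"value": [1, 2, 3]})     # three data rows
    state_pdf = pd.DataFrame({"properties": ["{}"]})  # one state row

    # Pad the state side with empty (all-null) rows so both frames have the same
    # number of rows before they are combined into one record batch.
    missing = len(data_pdf) - len(state_pdf)
    state_pdf = pd.concat(
        [state_pdf, pd.DataFrame({"properties": [None] * missing})], ignore_index=True
    )

    batch = pa.RecordBatch.from_pandas(pd.concat([data_pdf, state_pdf], axis=1))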
+ """ + + import pandas as pd + import pyarrow as pa + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + + empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt + empty_row_cnt_in_state = max_data_cnt - state_data_cnt + + empty_rows_pdf = pd.DataFrame( + dict.fromkeys(pa.schema(pdf_schema).names), + index=[x for x in range(0, empty_row_cnt_in_data)]) + empty_rows_state = pd.DataFrame( + columns=['properties', 'keyRowAsUnsafe', 'object', 'oldTimeoutTimestamp'], + index=[x for x in range(0, empty_row_cnt_in_state)]) + + pdfs.append(empty_rows_pdf) + state_pdfs.append(empty_rows_state) + + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + + return self._create_batch([ + (merged_pdf, pdf_schema), + (merged_state_pdf, self.result_state_pdf_arrow_type)]) + + def init_stream_yield_batches(): + import pandas as pd + + should_write_start_length = True + + pdfs = [] + state_pdfs = [] + return_schema = None + + pdf_data_cnt = 0 + state_data_cnt = 0 + + sampled_data_size_per_row = 0 + + last_purged_time_ns = time.time_ns() + + for data in iterator: + packaged_result = data[0] + + pdf_iter = packaged_result[0][0] + state = packaged_result[0][1] + # this won't change across batches + return_schema = packaged_result[1] + + for pdf in pdf_iter: + if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) + + if sampled_data_size_per_row == 0 and \ + pdf_data_cnt > self.min_data_count_for_sample: + memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] + sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt + + # This effectively works after the sampling has completed, size we multiply + # by 0 if the sampling is still in progress. + batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) >= \ + self.soft_limit_bytes_per_batch + + if batch_over_limit_on_size: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + last_purged_time_ns = time.time_ns() + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + # pick up state for only last chunk as state should have been updated so far + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + state_old_timeout_timestamp = state.oldTimeoutTimestamp + + state_dict = { + 'properties': [state_properties, ], + 'keyRowAsUnsafe': [state_key_row_as_binary, ], + 'object': [state_object, ], + 'oldTimeoutTimestamp': [state_old_timeout_timestamp, ], + } + + state_pdf = pd.DataFrame.from_dict(state_dict) + + state_pdfs.append(state_pdf) + state_data_cnt += 1 + + cur_time_ns = time.time_ns() + is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \ + self.soft_timeout_millis_purge_batch + if is_timed_out_on_purge: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, state_data_cnt) + + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + last_purged_time_ns = cur_time_ns + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + # end of loop, we may have remaining data + if pdf_data_cnt > 0 or state_data_cnt > 0: + batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, + state_pdfs, 
state_data_cnt) + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d040..da9a245bb711 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ def returnType(self) -> DataType: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(WithState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(WithState): return type must be a " "StructType." ) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d8..59d2f5b9d61e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. has_resource_module = True @@ -57,6 +58,7 @@ ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType @@ -207,6 +209,61 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + def wrapped(key_series, value_series_gen, state): + import pandas as pd + + key = tuple(s[0] for s in key_series) + + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. + values = [pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), ] + else: + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result_iter = f(key, values, state) + + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or + len(result.columns) == 0 and result.empty + ): + raise RuntimeError( + "Number of columns of the element (pandas.DataFrame) in return iterator " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + return result + + if isinstance(result_iter, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) + ) + + try: + iter(result_iter) + except TypeError: + raise TypeError( + "Return type of the user-defined function should be " + "iterable, but is {}".format(type(result_iter)) + ) + + result_iter_with_validation = (verify_element(x) for x in result_iter) + + return (result_iter_with_validation, state, ) + + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -311,6 +368,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -336,6 +395,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +405,10 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + state_object_schema = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -361,6 +425,29 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + soft_limit_bytes_per_batch = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", + (64 * 1024 * 1024) + ) + soft_limit_bytes_per_batch = int(soft_limit_bytes_per_batch) + + min_data_count_for_sample = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100 + ) + min_data_count_for_sample = int(min_data_count_for_sample) + + soft_timeout_millis_purge_batch = runner_conf.get( + "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", 100 + ) + soft_timeout_millis_purge_batch = int(soft_timeout_millis_purge_batch) + + ser = ApplyInPandasWithStateSerializer( + timezone, safecheck, assign_cols_by_name, + state_object_schema, + soft_limit_bytes_per_batch, + min_data_count_for_sample, + soft_timeout_millis_purge_batch) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() else: @@ -486,6 +573,35 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == 
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + from itertools import tee + + state = a[1] + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. + keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c11ce7d3b90f..99ba3802097b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging { case s: Aggregate if s.isStreaming => true case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true case d: Deduplicate if d.isStreaming => true case _ => false } @@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } + val applyInPandasWithStates = plan.collect { + case f: FlatMapGroupsInPandasWithState if f.isStreaming => f + } + + // Disallow multiple `applyInPandasWithState`s. 
+ if (applyInPandasWithStates.size >= 2) { + throwError( + "Multiple applyInPandasWithStates are not supported on a streaming " + + "DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging { } } + // applyInPandasWithState + case m: FlatMapGroupsInPandasWithState if m.isStreaming => + // Check compatibility with output modes and aggregations in query + val aggsInQuery = collectStreamingAggregates(plan) + + if (aggsInQuery.isEmpty) { + // applyInPandasWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "applyInPandasWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "applyInPandasWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // applyInPandasWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "applyInPandasWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "applyInPandasWithState in append mode is not supported after " + + "aggregation on a streaming DataFrame/Dataset") + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "applyInPandasWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b350834..e97ff7808f17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +/** + * Similar with [[FlatMapGroupsWithState]]. 
Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with an pandas DataFrame representation and the + * grouping key (tuple). + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index de25c19a26eb..7f36417dfe0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2705,6 +2705,44 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch") + .internal() + .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " + + "records that can be written to a single ArrowRecordBatch in memory. This is used to " + + "restrict the amount of memory being used to materialize the data in both executor and " + + "Python worker. The accumulated size of records are calculated via sampling a set of " + + "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " + + "is quite huge, the size of constructed ArrowRecordBatch will be around the " + + "configured value.") + .version("3.4.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("64MB") + + val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE = + buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample") + .internal() + .doc("When using applyInPandasWithState, specify the minimum number of records to sample " + + "the size of record. The size being retrieved from sampling will be used to estimate " + + "the accumulated size of records. Note that limiting by size does not work if the " + + "number of records are less than the configured value. For such case, ArrowRecordBatch " + + "will only be split for soft timeout.") + .version("3.4.0") + .intConf + .createWithDefault(100) + + val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH = + buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch") + .internal() + .doc("When using applyInPandasWithState, specify the soft timeout for purging the " + + "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " + + "the ArrowRecordBatch regardless of estimated size. 
This config ensures the receiver " + + "of data (both executor and Python worker) to not wait indefinitely for sender to " + + "complete the ArrowRecordBatch, which may hurt both throughput and latency.") + .version("3.4.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("100ms") + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -4529,6 +4567,15 @@ class SQLConf extends Serializable with Logging { def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) + def softLimitBytesPerBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) + + def minDataCountForSampleInApplyInPandasWithState: Int = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) + + def softTimeoutMillisPurgeBatchInApplyInPandasWithState: Long = + getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee3252187..69eb8101abf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,35 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def applyInPandasWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = FlatMapGroupsInPandasWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7bea..062eae128f61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FlatMapGroupsInPandasWithState( + func, groupAttr, outputAttr, stateType, outputMode, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = python.FlatMapGroupsInPandasWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. */ @@ -793,6 +812,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case _: FlatMapGroupsInPandasWithState => + throw new UnsupportedOperationException( + "applyInPandasWithState is unsupported in batch query. Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 7abca5f0e332..34e128a4925f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -44,7 +44,7 @@ object ArrowWriter { new ArrowWriter(root, children.toArray) } - private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala new file mode 100644 index 000000000000..213c9f4e712b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + + +/** + * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. + */ +class ApplyInPandasWithStatePythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + inputSchema: StructType, + override protected val timeZoneId: String, + initialWorkerConf: Map[String, String], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + valueSchema: StructType, + stateValueSchema: StructType, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + with PythonArrowInput[InType] + with PythonArrowOutput[OutType] { + + override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> + softLimitBytesPerBatch.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> + minDataCountForSample.toString) + + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> + softTimeoutMillsPurgeBatch.toString) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the schema for state value + PythonRDD.writeUTF(stateValueSchema.json, stream) + } + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Unit = { + val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch, + minDataCountForSample, softTimeoutMillsPurgeBatch) + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + w.startNewGroup(keyRow, groupState) + + while (dataIter.hasNext) { + val dataRow = dataIter.next() + w.writeRow(dataRow) + } + + w.finalizeGroup() + } + + w.finalizeData() + } + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. + assert(batch.numRows() > 0) + assert(schema.length == 2) + + def getColumnarBatchForStructTypeColumn( + batch: ColumnarBatch, + ordinal: Int, + expectedType: StructType): ColumnarBatch = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector] + val dataType = schema(ordinal).dataType.asInstanceOf[StructType] + assert(dataType.sameType(expectedType)) + + val outputVectors = dataType.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch + } + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, valueSchema) + dataBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. + None + } else { + Some(row) + } + } + } + + def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) + + stateMetadataBatch.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // NOTE: See StateReaderIterator.STATE_METADATA_SCHEMA for the schema. 
+ val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) + } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) + } + } + } + + (constructIterForState(batch), constructIterForData(batch)) + } +} + +object ApplyInPandasWithStatePythonRunner { + type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) + type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long) + type OutType = (Iterator[OutTypeForState], Iterator[InternalRow]) + + val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("oldTimeoutTimestamp", LongType) + ) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala new file mode 100644 index 000000000000..7278ad410740 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + + +class ApplyInPandasWithStateWriter( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + softLimitBytesPerBatch: Long, + minDataCountForSample: Int, + softTimeoutMillsPurgeBatch: Long) { + + import ApplyInPandasWithStateWriter._ + + // We logically group the columns by family and initialize writer separately, since it's + // lot more easier and probably performant to write the row directly rather than + // projecting the row to match up with the overall schema. + // + // The number of data rows and state metadata rows can be different which seems to matter + // for Arrow RecordBatch, so we append empty rows to cover it. + // + // We always produce at least one data row per grouping key whereas we only produce one + // state metadata row per grouping key, so we only need to fill up the empty rows in + // state metadata side. + private val arrowWriterForData = createArrowWriter(root.getFieldVectors.asScala.dropRight(1)) + private val arrowWriterForState = createArrowWriter(root.getFieldVectors.asScala.takeRight(1)) + + // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to + // gain the performance. In many cases, the amount of data per grouping key is quite + // small, which does not seem to maximize the benefits of using Arrow. + // + // We have to split the record batch down to each group in Python worker to convert the + // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split + // the range of data and give a view, say, "zero-copy". To help splitting the range for + // data, we provide the "start offset" and the "number of data" in the state metadata. + // + // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft + // limit on the size - it's not a hard limit since we allow current group to write all + // data even it's going to exceed the limit. + // + // We perform some basic sampling for data to guess the size of the data very roughly, + // and simply multiply by the number of data to estimate the size. We extract the size of + // data from the record batch rather than UnsafeRow, as we don't hold the memory for + // UnsafeRow once we write to the record batch. If there is a memory bound here, it + // should come from record batch. + // + // In the meanwhile, we don't also want to let the current record batch collect the data + // indefinitely, since we are pipelining the process between executor and python worker. + // Python worker won't process any data if executor is not yet finalized a record + // batch, which defeats the purpose of pipelining. To address this, we also introduce + // timeout for constructing a record batch. 
This is a soft limit indeed as same as limit + // on the size - we allow current group to write all data even it's timed-out. + + private var numRowsForCurGroup = 0 + private var startOffsetForCurGroup = 0 + private var totalNumRowsForBatch = 0 + private var totalNumStatesForBatch = 0 + + private var sampledDataSizePerRow = 0 + private var lastBatchPurgedMillis = System.currentTimeMillis() + + private var currentGroupKeyRow: UnsafeRow = _ + private var currentGroupState: GroupStateImpl[Row] = _ + + def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = { + currentGroupKeyRow = keyRow + currentGroupState = groupState + } + + def writeRow(dataRow: InternalRow): Unit = { + // Currently, this only works when the number of rows are greater than the minimum + // data count for sampling. And we technically have no way to pick some rows from + // record batch and measure the size of data, hence we leverage all data in current + // record batch. We only sample once as it could be costly. + if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { + sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch + } + + // If it exceeds the condition of batch (only size, not about timeout) and + // there is more data for the same group, flush and construct a new batch. + + // The soft-limit on size effectively works after the sampling has completed, since we + // multiply the number of rows by 0 if the sampling is still in progress. + + if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { + // Provide state metadata row as intermediate + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + finalizeCurrentArrowBatch() + } + + arrowWriterForData.write(dataRow) + numRowsForCurGroup += 1 + totalNumRowsForBatch += 1 + } + + def finalizeGroup(): Unit = { + // Provide state metadata row + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for next group would be same as the total number of rows for batch, + // unless the next group starts with new batch. + startOffsetForCurGroup = totalNumRowsForBatch + + // The soft-limit on timeout applies on finalization of each group. + if (System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { + finalizeCurrentArrowBatch() + } + } + + def finalizeData(): Unit = { + if (numRowsForCurGroup > 0) { + // We still have some rows in the current record batch. Need to flush them as well. 
+ finalizeCurrentArrowBatch() + } + } + + private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = { + val children = fieldVectors.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + private def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurGroup = 0 + numRowsForCurGroup = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + lastBatchPurgedMillis = System.currentTimeMillis() + } +} + +object ApplyInPandasWithStateWriter { + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("startOffset", IntegerType), + StructField("numRows", IntegerType), + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b5466..b39787b12a48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f998..f0e815e966e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 000000000000..098d97af71ca --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-XXXXX): Add the support of initial state. 
+ override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + + override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + val processIter = groupedIter.map { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj)) + } + + process(processIter, hasTimedOut = false) + } + + override def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + + override def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + + val processIter = timingOutPairs.map { stateData => + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + + (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) + } + + process(processIter, hasTimedOut = true) + } else Iterator.empty + } + + private def process( + iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val runner = new ApplyInPandasWithStatePythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, + 
stateEncoder.asInstanceOf[ExpressionEncoder[Row]], + groupingAttributes.toStructType, + child.output.toStructType, + stateType, + conf.softLimitBytesPerBatchInApplyInPandasWithState, + conf.minDataCountForSampleInApplyInPandasWithState, + conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState) + + val context = TaskContext.get() + + val processIter = iter.map { case (keyRow, stateData, valueIter) => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + (keyRow, groupedState, valueIter) + } + runner.compute(processIter, context.partitionId(), context).flatMap { + case (stateIter, outputIter) => + // When the iterator is consumed, then write changes to state. + // state does not affect each others, hence when to update does not affect to the result. + def onIteratorCompletion: Unit = { + stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) => + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIter, onIteratorCompletion).map { row => + numOutputRows += 1 + row + } + } + } + + override protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4e..4e98f86d969f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. 
] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e973..f386282a0b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,6 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,13 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: FlatMapGroupsInPandasWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7bc..022fd1239ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, From 79ba311456a085a793dbae2b77140b4264c49387 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 15 Sep 2022 11:26:20 +0900 Subject: [PATCH 02/38] meta-commit to credit properly on co-authorship From 00009940bed54fe0b657400a3bdd5a1a7060c723 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 15 Sep 2022 16:44:50 +0900 Subject: [PATCH 03/38] replace SPARK-XXXXX with real JIRA ticket --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 + .../execution/python/FlatMapGroupsInPandasWithStateExec.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 062eae128f61..c64a123e3a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -813,6 +813,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hasInitialState, planLater(initialState), planLater(child) ) :: Nil case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-40443): support applyInPandasWithState in batch query throw new UnsupportedOperationException( "applyInPandasWithState is unsupported in batch query. Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 098d97af71ca..a106f3770a99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -64,7 +64,7 @@ case class FlatMapGroupsInPandasWithStateExec( eventTimeWatermark: Option[Long], child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { - // TODO(SPARK-XXXXX): Add the support of initial state. + // TODO(SPARK-40444): Add the support of initial state. 
override protected val initialStateDeserializer: Expression = null override protected val initialStateGroupAttrs: Seq[Attribute] = null override protected val initialStateDataAttrs: Seq[Attribute] = null From caa4924d6ed846799ebbf00bfa0e980519a3f86e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 10:51:33 +0900 Subject: [PATCH 04/38] fix --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7f36417dfe0e..c8acb8ac09cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4575,7 +4575,7 @@ class SQLConf extends Serializable with Logging { def softTimeoutMillisPurgeBatchInApplyInPandasWithState: Long = getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH) - + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) From 76bb4043c9642a843159b14fe8aa2b3cc6b4ea92 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 12:53:08 +0900 Subject: [PATCH 05/38] Add missing piece --- .../spark/sql/execution/arrow/ArrowWriter.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 34e128a4925f..2988c0fb5187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } + def sizeInBytes(): Int = { + var i = 0 + var bytes = 0 + while (i < fields.size) { + bytes += fields(i).getSizeInBytes() + i += 1 + } + bytes + } + def finish(): Unit = { root.setRowCount(count) fields.foreach(_.finish()) @@ -132,6 +142,10 @@ private[arrow] abstract class ArrowFieldWriter { count += 1 } + def getSizeInBytes(): Int = { + valueVector.getBufferSizeFor(count) + } + def finish(): Unit = { valueVector.setValueCount(count) } From a2f25e3d17a768f3e225db27a2fd1b050369469f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 15:03:55 +0900 Subject: [PATCH 06/38] unused import --- .../apache/spark/sql/execution/python/PandasGroupUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 4e98f86d969f..078876664062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** From 
bb3a80aff959793c520fe4049538f96076fe844d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 17:53:21 +0900 Subject: [PATCH 07/38] fix for Scala 2.13 --- .../sql/execution/python/ApplyInPandasWithStateWriter.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala index 7278ad410740..ed5af35a1373 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -52,8 +52,10 @@ class ApplyInPandasWithStateWriter( // We always produce at least one data row per grouping key whereas we only produce one // state metadata row per grouping key, so we only need to fill up the empty rows in // state metadata side. - private val arrowWriterForData = createArrowWriter(root.getFieldVectors.asScala.dropRight(1)) - private val arrowWriterForState = createArrowWriter(root.getFieldVectors.asScala.takeRight(1)) + private val arrowWriterForData = createArrowWriter( + root.getFieldVectors.asScala.toSeq.dropRight(1)) + private val arrowWriterForState = createArrowWriter( + root.getFieldVectors.asScala.toSeq.takeRight(1)) // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to // gain the performance. In many cases, the amount of data per grouping key is quite From 4d3c7e9c7aa51cc7cf25b103124fc7703f3149c4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 16 Sep 2022 21:38:38 +0900 Subject: [PATCH 08/38] reformat as suggested by linter --- .../pyspark/sql/pandas/_typing/__init__.pyi | 5 +- python/pyspark/sql/pandas/serializers.py | 135 ++++++++++++------ python/pyspark/worker.py | 21 ++- 3 files changed, 106 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 9f855c6c1151..01ae703d33f1 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -258,8 +258,9 @@ PandasGroupedMapFunction = Union[ Callable[[Any, DataFrameLike], DataFrameLike], ] -PandasGroupedMapFunctionWithState = Callable[[Any, Iterable[DataFrameLike], GroupStateImpl], - Iterable[DataFrameLike]] +PandasGroupedMapFunctionWithState = Callable[ + [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike] +] class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... 
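
The serializer diff below only reformats existing code, but it is the densest part of this change: it bin-packs pandas DataFrames into Arrow record batches under a soft size limit (estimated by sampling the per-row size) and a soft purge timeout. As a reading aid, here is a minimal standalone sketch of that policy; the class name, parameter names, and the combined flush check are illustrative assumptions, and the real code in this patch samples the row size only once, applies the size check per data chunk, and applies the timeout check when a group is finalized.

    import time

    import pandas as pd


    class SoftBatchPolicy:
        """Sketch of the batching policy: flush when the estimated batch size crosses a
        soft limit, or when a soft timeout has elapsed since the last flush."""

        def __init__(self, soft_limit_bytes, min_rows_for_sample, soft_timeout_millis):
            self.soft_limit_bytes = soft_limit_bytes
            self.min_rows_for_sample = min_rows_for_sample
            self.soft_timeout_millis = soft_timeout_millis
            self.pdfs = []         # pending pandas DataFrames for the current batch
            self.row_count = 0     # total rows pending in the current batch
            self.size_per_row = 0  # sampled once, then reused for the size estimate
            self.last_flush_ns = time.time_ns()

        def add(self, pdf):
            self.pdfs.append(pdf)
            self.row_count += len(pdf)
            if self.size_per_row == 0 and self.row_count > self.min_rows_for_sample:
                # Sample the average in-memory size per row once enough rows are buffered.
                total_bytes = sum(p.memory_usage(deep=True).sum() for p in self.pdfs)
                self.size_per_row = total_bytes / self.row_count

        def should_flush(self):
            # Before sampling completes, size_per_row is 0, so the size check never triggers.
            over_size = self.size_per_row * self.row_count >= self.soft_limit_bytes
            elapsed_ms = (time.time_ns() - self.last_flush_ns) // 1000000
            return over_size or elapsed_ms >= self.soft_timeout_millis

        def flush(self):
            merged = pd.concat(self.pdfs, ignore_index=True) if self.pdfs else pd.DataFrame()
            self.pdfs = []
            self.row_count = 0
            self.last_flush_ns = time.time_ns()
            return merged

In the patch, the three thresholds are not hard-coded like this; they come from SQL configs that the runner forwards to the Python worker.
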
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index e561e8723819..b1f9c07d95d8 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -378,22 +378,31 @@ def load_stream(self, stream): class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): - - def __init__(self, timezone, safecheck, assign_cols_by_name, state_object_schema, - soft_limit_bytes_per_batch, min_data_count_for_sample, - soft_timeout_millis_purge_batch): + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + state_object_schema, + soft_limit_bytes_per_batch, + min_data_count_for_sample, + soft_timeout_millis_purge_batch, + ): super(ApplyInPandasWithStateSerializer, self).__init__( - timezone, safecheck, assign_cols_by_name) + timezone, safecheck, assign_cols_by_name + ) self.pickleSer = CPickleSerializer() self.utf8_deserializer = UTF8Deserializer() self.state_object_schema = state_object_schema - self.result_state_df_type = StructType([ - StructField('properties', StringType()), - StructField('keyRowAsUnsafe', BinaryType()), - StructField('object', BinaryType()), - StructField('oldTimeoutTimestamp', LongType()), - ]) + self.result_state_df_type = StructType( + [ + StructField("properties", StringType()), + StructField("keyRowAsUnsafe", BinaryType()), + StructField("object", BinaryType()), + StructField("oldTimeoutTimestamp", LongType()), + ] + ) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) self.soft_limit_bytes_per_batch = soft_limit_bytes_per_batch @@ -412,14 +421,23 @@ def gen_data_and_state(batches): for batch in batches: batch_schema = batch.schema data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) - state_schema = pa.schema([batch_schema[-1], ]) + state_schema = pa.schema( + [ + batch_schema[-1], + ] + ) batch_columns = batch.columns data_columns = batch_columns[0:-1] state_column = batch_columns[-1] data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) - state_batch = pa.RecordBatch.from_arrays([state_column, ], schema=state_schema) + state_batch = pa.RecordBatch.from_arrays( + [ + state_column, + ], + schema=state_schema, + ) state_arrow = pa.Table.from_batches([state_batch]).itercolumns() state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] @@ -431,13 +449,13 @@ def gen_data_and_state(batches): # no more data with grouping key + state break - state_info_col_properties = state_info_col['properties'] - state_info_col_key_row = state_info_col['keyRowAsUnsafe'] - state_info_col_object = state_info_col['object'] + state_info_col_properties = state_info_col["properties"] + state_info_col_key_row = state_info_col["keyRowAsUnsafe"] + state_info_col_object = state_info_col["object"] - data_start_offset = state_info_col['startOffset'] - num_data_rows = state_info_col['numRows'] - is_last_chunk = state_info_col['isLastChunk'] + data_start_offset = state_info_col["startOffset"] + num_data_rows = state_info_col["numRows"] + is_last_chunk = state_info_col["isLastChunk"] state_properties = json.loads(state_info_col_properties) if state_info_col_object: @@ -452,9 +470,11 @@ def gen_data_and_state(batches): state = state_for_current_group else: # there is no state being stored for same group, construct one - state = GroupStateImpl(keyAsUnsafe=state_info_col_key_row, - valueSchema=self.state_object_schema, - **state_properties) + state = GroupStateImpl( + keyAsUnsafe=state_info_col_key_row, + 
valueSchema=self.state_object_schema, + **state_properties, + ) if is_last_chunk: # discard the state being cached for same group @@ -470,7 +490,10 @@ def gen_data_and_state(batches): data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] # state info - yield (data_pandas, state, ) + yield ( + data_pandas, + state, + ) batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) @@ -478,7 +501,10 @@ def gen_data_and_state(batches): # state will be same object for same grouping key for state, data in groupby(data_state_generator, key=lambda x: x[1]): - yield (data, state, ) + yield ( + data, + state, + ) def dump_stream(self, iterator, stream): """ @@ -503,10 +529,12 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat empty_rows_pdf = pd.DataFrame( dict.fromkeys(pa.schema(pdf_schema).names), - index=[x for x in range(0, empty_row_cnt_in_data)]) + index=[x for x in range(0, empty_row_cnt_in_data)], + ) empty_rows_state = pd.DataFrame( - columns=['properties', 'keyRowAsUnsafe', 'object', 'oldTimeoutTimestamp'], - index=[x for x in range(0, empty_row_cnt_in_state)]) + columns=["properties", "keyRowAsUnsafe", "object", "oldTimeoutTimestamp"], + index=[x for x in range(0, empty_row_cnt_in_state)], + ) pdfs.append(empty_rows_pdf) state_pdfs.append(empty_rows_state) @@ -514,9 +542,9 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat merged_pdf = pd.concat(pdfs, ignore_index=True) merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) - return self._create_batch([ - (merged_pdf, pdf_schema), - (merged_state_pdf, self.result_state_pdf_arrow_type)]) + return self._create_batch( + [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)] + ) def init_stream_yield_batches(): import pandas as pd @@ -547,19 +575,23 @@ def init_stream_yield_batches(): pdf_data_cnt += len(pdf) pdfs.append(pdf) - if sampled_data_size_per_row == 0 and \ - pdf_data_cnt > self.min_data_count_for_sample: + if ( + sampled_data_size_per_row == 0 + and pdf_data_cnt > self.min_data_count_for_sample + ): memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt # This effectively works after the sampling has completed, size we multiply # by 0 if the sampling is still in progress. 
- batch_over_limit_on_size = (sampled_data_size_per_row * pdf_data_cnt) >= \ - self.soft_limit_bytes_per_batch + batch_over_limit_on_size = ( + sampled_data_size_per_row * pdf_data_cnt + ) >= self.soft_limit_bytes_per_batch if batch_over_limit_on_size: - batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, - state_pdfs, state_data_cnt) + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) pdfs = [] state_pdfs = [] @@ -580,10 +612,18 @@ def init_stream_yield_batches(): state_old_timeout_timestamp = state.oldTimeoutTimestamp state_dict = { - 'properties': [state_properties, ], - 'keyRowAsUnsafe': [state_key_row_as_binary, ], - 'object': [state_object, ], - 'oldTimeoutTimestamp': [state_old_timeout_timestamp, ], + "properties": [ + state_properties, + ], + "keyRowAsUnsafe": [ + state_key_row_as_binary, + ], + "object": [ + state_object, + ], + "oldTimeoutTimestamp": [ + state_old_timeout_timestamp, + ], } state_pdf = pd.DataFrame.from_dict(state_dict) @@ -592,11 +632,13 @@ def init_stream_yield_batches(): state_data_cnt += 1 cur_time_ns = time.time_ns() - is_timed_out_on_purge = ((cur_time_ns - last_purged_time_ns) // 1000000) >= \ - self.soft_timeout_millis_purge_batch + is_timed_out_on_purge = ( + (cur_time_ns - last_purged_time_ns) // 1000000 + ) >= self.soft_timeout_millis_purge_batch if is_timed_out_on_purge: - batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, - state_pdfs, state_data_cnt) + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) pdfs = [] state_pdfs = [] @@ -612,8 +654,9 @@ def init_stream_yield_batches(): # end of loop, we may have remaining data if pdf_data_cnt > 0 or state_data_cnt > 0: - batch = construct_record_batch(pdfs, pdf_data_cnt, return_schema, - state_pdfs, state_data_cnt) + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 59d2f5b9d61e..2c222359bec0 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -217,7 +217,9 @@ def wrapped(key_series, value_series_gen, state): if state.hasTimedOut: # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. 
- values = [pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), ] + values = [ + pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), + ] else: values = (pd.concat(x, axis=1) for x in value_series_gen) @@ -232,8 +234,7 @@ def verify_element(result): # the number of columns of result have to match the return type # but it is fine for result to have no columns at all if it is empty if not ( - len(result.columns) == len(return_type) or - len(result.columns) == 0 and result.empty + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty ): raise RuntimeError( "Number of columns of the element (pandas.DataFrame) in return iterator " @@ -259,7 +260,10 @@ def verify_element(result): result_iter_with_validation = (verify_element(x) for x in result_iter) - return (result_iter_with_validation, state, ) + return ( + result_iter_with_validation, + state, + ) return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] @@ -428,7 +432,7 @@ def read_udfs(pickleSer, infile, eval_type): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: soft_limit_bytes_per_batch = runner_conf.get( "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", - (64 * 1024 * 1024) + (64 * 1024 * 1024), ) soft_limit_bytes_per_batch = int(soft_limit_bytes_per_batch) @@ -443,11 +447,14 @@ def read_udfs(pickleSer, infile, eval_type): soft_timeout_millis_purge_batch = int(soft_timeout_millis_purge_batch) ser = ApplyInPandasWithStateSerializer( - timezone, safecheck, assign_cols_by_name, + timezone, + safecheck, + assign_cols_by_name, state_object_schema, soft_limit_bytes_per_batch, min_data_count_for_sample, - soft_timeout_millis_purge_batch) + soft_timeout_millis_purge_batch, + ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() else: From b93e488c6e4e522f3e529cf025e47d5617605cba Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 13:21:28 +0900 Subject: [PATCH 09/38] Update sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala Co-authored-by: Hyukjin Kwon --- .../python/FlatMapGroupsInPandasWithStateExec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index a106f3770a99..6c5988037455 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -65,10 +65,10 @@ case class FlatMapGroupsInPandasWithStateExec( child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. 
- override protected val initialStateDeserializer: Expression = null - override protected val initialStateGroupAttrs: Seq[Attribute] = null - override protected val initialStateDataAttrs: Seq[Attribute] = null - override protected val initialState: SparkPlan = null + override protected val initialStateDeserializer: Expression = _ + override protected val initialStateGroupAttrs: Seq[Attribute] = _ + override protected val initialStateDataAttrs: Seq[Attribute] = _ + override protected val initialState: SparkPlan = _ override protected val hasInitialState: Boolean = false override protected val stateEncoder: ExpressionEncoder[Any] = From 5d23a6df8e78e95b66b40f5f0ea2cfa406efc3bc Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 13:34:21 +0900 Subject: [PATCH 10/38] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala Co-authored-by: Hyukjin Kwon --- .../sql/catalyst/analysis/UnsupportedOperationChecker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 99ba3802097b..84795203fd17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -148,7 +148,7 @@ object UnsupportedOperationChecker extends Logging { } // Disallow multiple `applyInPandasWithState`s. - if (applyInPandasWithStates.size >= 2) { + if (applyInPandasWithStates.size > 1) { throwError( "Multiple applyInPandasWithStates are not supported on a streaming " + "DataFrames/Datasets")(plan) From e757be0171307f6564ee6e653cf2f5233ac853b1 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 13:36:20 +0900 Subject: [PATCH 11/38] 1st reflection of feedbacks --- python/pyspark/sql/pandas/group_ops.py | 36 ++++++++++++-------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index f845f2466011..8979cf77fedc 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -231,13 +231,11 @@ def applyInPandasWithState( per-group state. The result Dataset will represent the flattened record returned by the function. - For a streaming Dataset, the function will be invoked for each group repeatedly in every - trigger, and updates to each group's state will be saved across invocations. The function - will also be invoked for each timed-out state repeatedly. The sequence of the invocation - will be input data -> state timeout. When the function is invoked for state timeout, there - will be no data being presented. + For a streaming Dataset, the function will be invoked first for all input groups and then + for all timed out states where the input data is set to be empty. Updates to each group's + state will be saved across invocations. - The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and + The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as :class:`pyspark.sql.streaming.state.GroupStateImpl`. 
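
Since the example section of this docstring is still marked as a TODO later in this series, here is a hedged usage sketch matching the contract described above; the streaming DataFrame `df`, its `key`/`value` columns, and the output mode string are assumptions for illustration, not part of the patch.

    import pandas as pd
    from typing import Iterator, Tuple
    from pyspark.sql.streaming.state import GroupStateTimeout

    # Running count of rows per key; the state is a single-column tuple (count,).
    def count_fn(key: Tuple, pdf_iter: Iterator[pd.DataFrame], state) -> Iterator[pd.DataFrame]:
        count = state.get[0] if state.exists else 0
        for pdf in pdf_iter:
            count += len(pdf)
        state.update((count,))
        yield pd.DataFrame({"key": [key[0]], "count": [count]})

    # `df` is assumed to be a streaming DataFrame with columns (key, value).
    result = df.groupBy("key").applyInPandasWithState(
        count_fn,
        outputStructType="key string, count long",
        stateStructType="count long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )

With a timeout other than NoTimeout, the same function is also invoked for timed-out keys with an empty input, which is when `state.hasTimedOut` is true.
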
@@ -249,29 +247,27 @@ def applyInPandasWithState( elements in the iterator. The `outputStructType` should be a :class:`StructType` describing the schema of all - elements in returned value, `pandas.DataFrame`. The column labels of all elements in - returned value, `pandas.DataFrame` must either match the field names in the defined - schema if specified as strings, or match the field data types by position if not strings, + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, e.g. integer indices. - The `stateStructType` should be :class:`StructType` describing the schema of user-defined - state. The value of state will be presented as a tuple, as well as the update should be - performed with the tuple. User defined types e.g. native Python class types are not - supported. Alternatively, you can pickle the data and produce the data as BinaryType, but - it is tied to the backward and forward compatibility of pickle in Python, and Spark itself - does not guarantee the compatibility. + The `stateStructType` should be :class:`StructType` describing the schema of the + user-defined state. The value of the state will be presented as a tuple, as well as the + update should be performed with the tuple. User defined types e.g. native Python class + types are not supported. - The length of each element in both input and returned value, `pandas.DataFrame`, can be - arbitrary. The length of iterator in both input and returned value can be also arbitrary. + The size of each DataFrame in both the input and output can be arbitrary. The number of + DataFrames in both the input and output can also be arbitrary. .. versionadded:: 3.4.0 Parameters ---------- func : function - a Python native function to be called on every group. It should takes parameters - (key, Iterator[`pandas.DataFrame`], state) and returns Iterator[`pandas.DataFrame`]. - Note that the type of key is tuple, and the type of state is + a Python native function to be called on every group. It should take parameters + (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`]. + Note that the type of the key is tuple and the type of the state is :class:`pyspark.sql.streaming.state.GroupStateImpl`. outputStructType : :class:`pyspark.sql.types.DataType` or str the type of the output records. 
The value can be either a From 43929ac1629889af5dd9c471fa0d7f1c865548ad Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 14:08:29 +0900 Subject: [PATCH 12/38] reflect suggestion --- .../execution/python/ApplyInPandasWithStatePythonRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 213c9f4e712b..ec9971fc613e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -61,7 +61,7 @@ class ApplyInPandasWithStatePythonRunner( with PythonArrowInput[InType] with PythonArrowOutput[OutType] { - override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA) + override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback From e7259955968613d245d21df22da5663f36da57a4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 18:55:52 +0900 Subject: [PATCH 13/38] WIP still documenting... --- python/pyspark/sql/pandas/group_ops.py | 2 + python/pyspark/sql/pandas/serializers.py | 66 +++++++++++++++++++ .../ApplyInPandasWithStatePythonRunner.scala | 6 +- .../python/ApplyInPandasWithStateWriter.scala | 46 ++++++++++--- .../execution/python/PythonArrowInput.scala | 1 - 5 files changed, 109 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 8979cf77fedc..b239ecacb9d7 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -280,6 +280,8 @@ def applyInPandasWithState( timeoutConf : str timeout configuration for groups that do not receive data for a while. valid values are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`. + + # TODO: Examples """ from pyspark.sql import GroupedData diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b1f9c07d95d8..2a559cde9bc6 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -378,6 +378,28 @@ def load_stream(self, stream): class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for applyInPandasWithState. + + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + state_object_schema : StructType + The type of state object represented as Spark SQL type + soft_limit_bytes_per_batch : int + Soft limit of the accumulated size of records that can be written to a single + ArrowRecordBatch in memory. + min_data_count_for_sample : int + The minimum number of records to sample the size of record. + soft_timeout_millis_purge_batch : int + The soft timeout to force purging the ArrowRecordBatch regardless of the size. 
+ """ + def __init__( self, timezone, @@ -410,12 +432,50 @@ def __init__( self.soft_timeout_millis_purge_batch = soft_timeout_millis_purge_batch def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of pair + (data chunk, state), and convert the data into a list of pandas.Series. + + Please refer the doc of inner function `gen_data_and_state` for more details how + this function works in overall. + + In addition, this function further groups the return of `gen_data_and_state` by the state + instance (same semantic as grouping by grouping key) and produces an iterator of data + chunks for each group, so that the caller can lazily materialize the data chunk. + """ + import pyarrow as pa import json from itertools import groupby from pyspark.sql.streaming.state import GroupStateImpl def gen_data_and_state(batches): + """ + Deserialize ArrowRecordBatches and return a generator of + `(a list of pandas.Series, state)`. + + The logic on deserialization is following: + + 1. Read the entire data part from Arrow RecordBatch. + 2. Read the entire state information part from Arrow RecordBatch. + 3. Loop through each state information: + 3.A. Extract the data out from entire data via the information of data range. + 3.B. Construct a new state instance if the state information is the first occurrence + for the current grouping key. + 3.C. Leverage existing new state instance if the state instance is already available + for the current grouping key. (Meaning it's not the first occurrence.) + 3.D. Remove the cache of state instance if the state information denotes the data is + the last chunk for current grouping key. + + This deserialization logic assumes that Arrow RecordBatches contain the data with the + ordering that data chunks for same grouping key will appear sequentially. + + This function must avoid materializing multiple Arrow RecordBatches into memory at the + same time. And data chunks from the same grouping key should appear sequentially, to + further group them based on state instance (same state instance will be produced for + same grouping key). + """ + state_for_current_group = None for batch in batches: @@ -508,6 +568,7 @@ def gen_data_and_state(batches): def dump_stream(self, iterator, stream): """ + # TODO: document Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. This should be sent after creating the first record batch so in case of an error, it can be sent back to the JVM before the Arrow stream starts. @@ -515,6 +576,7 @@ def dump_stream(self, iterator, stream): def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): """ + # TODO: document Arrow RecordBatch requires all columns to have all same number of rows. Insert empty data for state/data with less elements to compensate. 
""" @@ -547,6 +609,10 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat ) def init_stream_yield_batches(): + """ + # TODO: document + :return: + """ import pandas as pd should_write_start_length = True diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index ec9971fc613e..54330629e9f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -41,7 +41,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** - * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. + * A variant implementation of [[ArrowPythonRunner]] to serve the operation + * applyInPandasWithState. */ class ApplyInPandasWithStatePythonRunner( funcs: Seq[ChainedPythonFunctions], @@ -71,6 +72,9 @@ class ApplyInPandasWithStatePythonRunner( "Pandas execution requires more than 4 bytes. Please set higher buffer. " + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. + // Configurations are both applied to executor and Python worker, set them to the worker conf + // to let Python worker read the config properly. override protected val workerConf: Map[String, String] = initialWorkerConf + (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> softLimitBytesPerBatch.toString) + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala index ed5af35a1373..4161aa88083e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -32,7 +32,20 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String - +/** + * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with + * bin-packing and chunking. The caller only need to call the proper public methods of this class + * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data + * and state into Arrow RecordBatches with performing bin-pack and chunk internally. + * + * This class requires that the parameter `root` has initialized with the Arrow schema like below: + * - data fields + * - state field + * - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA) + * + * Please refer the code comment in the implementation to see how the writes of data and state + * against Arrow RecordBatch work with consideration of bin-packing and chunking. 
+ */ class ApplyInPandasWithStateWriter( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -42,12 +55,12 @@ class ApplyInPandasWithStateWriter( import ApplyInPandasWithStateWriter._ - // We logically group the columns by family and initialize writer separately, since it's - // lot more easier and probably performant to write the row directly rather than + // We logically group the columns by family (data vs state) and initialize writer separately, + // since it's lot more easier and probably performant to write the row directly rather than // projecting the row to match up with the overall schema. // - // The number of data rows and state metadata rows can be different which seems to matter - // for Arrow RecordBatch, so we append empty rows to cover it. + // The number of data rows and state metadata rows can be different which could be problematic + // for Arrow RecordBatch, so we append empty rows to ensure both have the same number of rows. // // We always produce at least one data row per grouping key whereas we only produce one // state metadata row per grouping key, so we only need to fill up the empty rows in @@ -57,6 +70,8 @@ class ApplyInPandasWithStateWriter( private val arrowWriterForState = createArrowWriter( root.getFieldVectors.asScala.toSeq.takeRight(1)) + // Bin-packing + // // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to // gain the performance. In many cases, the amount of data per grouping key is quite // small, which does not seem to maximize the benefits of using Arrow. @@ -66,9 +81,18 @@ class ApplyInPandasWithStateWriter( // the range of data and give a view, say, "zero-copy". To help splitting the range for // data, we provide the "start offset" and the "number of data" in the state metadata. // - // Pretty sure we don't bin-pack all groups into a single record batch. We have a soft - // limit on the size - it's not a hard limit since we allow current group to write all - // data even it's going to exceed the limit. + // We don't bin-pack all groups into a single record batch - we have a soft limit on the size + // of Arrow RecordBatch to stop adding next group. + // + // Chunking + // + // We also chunk the data from single group into multiple Arrow RecordBatch to ensure + // scalability. Note that we don't know the volume (number of rows, overall size) of data for + // specific group key before we read the entire data. The easiest approach to address both + // bin-pack and chunk is to check the size of the current Arrow RecordBatch per each write of + // row. + // + // How to measure the size of data? // // We perform some basic sampling for data to guess the size of the data very roughly, // and simply multiply by the number of data to estimate the size. We extract the size of @@ -76,12 +100,14 @@ class ApplyInPandasWithStateWriter( // UnsafeRow once we write to the record batch. If there is a memory bound here, it // should come from record batch. // + // Why we also perform timeout on batching Arrow RecordBatch? + // // In the meanwhile, we don't also want to let the current record batch collect the data // indefinitely, since we are pipelining the process between executor and python worker. // Python worker won't process any data if executor is not yet finalized a record // batch, which defeats the purpose of pipelining. To address this, we also introduce - // timeout for constructing a record batch. 
This is a soft limit indeed as same as limit - // on the size - we allow current group to write all data even it's timed-out. + // timeout for constructing a record batch. This is a soft limit - we allow current group + // to write all data even it's timed-out. private var numRowsForCurGroup = 0 private var startOffsetForCurGroup = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 6168d0f867ad..bf66791183ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() From 9c6bf60050f58773c99f0c50737b6de1886d0a5e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 18:56:26 +0900 Subject: [PATCH 14/38] Rename GroupStateImpl to GroupState in PySpark (No additional interface) --- python/pyspark/sql/pandas/_typing/__init__.pyi | 4 ++-- python/pyspark/sql/pandas/serializers.py | 4 ++-- python/pyspark/sql/streaming/state.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 01ae703d33f1..acca8c00f2aa 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -30,7 +30,7 @@ from typing_extensions import Protocol, Literal from types import FunctionType from pyspark.sql._typing import LiteralType -from pyspark.sql.streaming.state import GroupStateImpl +from pyspark.sql.streaming.state import GroupState from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray @@ -259,7 +259,7 @@ PandasGroupedMapFunction = Union[ ] PandasGroupedMapFunctionWithState = Callable[ - [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike] + [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike] ] class PandasVariadicGroupedAggFunction(Protocol): diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 2a559cde9bc6..ac134338af75 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -447,7 +447,7 @@ def load_stream(self, stream): import pyarrow as pa import json from itertools import groupby - from pyspark.sql.streaming.state import GroupStateImpl + from pyspark.sql.streaming.state import GroupState def gen_data_and_state(batches): """ @@ -530,7 +530,7 @@ def gen_data_and_state(batches): state = state_for_current_group else: # there is no state being stored for same group, construct one - state = GroupStateImpl( + state = GroupState( keyAsUnsafe=state_info_col_key_row, valueSchema=self.state_object_schema, **state_properties, diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 842eff322330..61625eca5d09 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -20,7 +20,7 @@ from pyspark.sql.types import DateType, Row, StructType -__all__ = ["GroupStateImpl", 
"GroupStateTimeout"] +__all__ = ["GroupState", "GroupStateTimeout"] class GroupStateTimeout: @@ -29,7 +29,7 @@ class GroupStateTimeout: EventTimeTimeout: str = "EventTimeTimeout" -class GroupStateImpl: +class GroupState: NO_TIMESTAMP: int = -1 def __init__( @@ -146,7 +146,7 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: raise ValueError("Timeout timestamp must be positive") if ( - self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + self._event_time_watermark_ms != GroupState.NO_TIMESTAMP and timestampMs < self._event_time_watermark_ms ): raise ValueError( From 69bb3e871396880ff2190efc646df1510486e86d Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 21:05:15 +0900 Subject: [PATCH 15/38] WIP still updating the doc --- python/pyspark/sql/streaming/state.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 61625eca5d09..bbadb253f352 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -24,12 +24,19 @@ class GroupStateTimeout: + """ + Represents the type of timeouts possible for the Dataset operations applyInPandasWithState. + """ NoTimeout: str = "NoTimeout" ProcessingTimeTimeout: str = "ProcessingTimeTimeout" EventTimeTimeout: str = "EventTimeTimeout" class GroupState: + """ + Wrapper class for interacting with per-group state data in `applyInPandasWithState`. + """ + NO_TIMESTAMP: int = -1 def __init__( @@ -76,10 +83,16 @@ def __init__( @property def exists(self) -> bool: + """ + Whether state exists or not. + """ return self._defined @property def get(self) -> Tuple: + """ + Get the state value if it exists, or throw ValueError. + """ if self.exists: return tuple(self._value) else: @@ -87,6 +100,9 @@ def get(self) -> Tuple: @property def getOption(self) -> Optional[Tuple]: + """ + Get the state value if it exists, or return None. + """ if self.exists: return tuple(self._value) else: @@ -94,6 +110,10 @@ def getOption(self) -> Optional[Tuple]: @property def hasTimedOut(self) -> bool: + """ + Whether the function has been called because the key has timed out. + This can return true only when timeouts are enabled. + """ return self._has_timed_out # NOTE: this function is only available to PySpark implementation due to underlying @@ -103,6 +123,9 @@ def oldTimeoutTimestamp(self) -> int: return self._old_timeout_timestamp def update(self, newValue: Tuple) -> None: + """ + Update the value of the state. The value of the state cannot be null. + """ if newValue is None: raise ValueError("'None' is not a valid state value") @@ -112,11 +135,18 @@ def update(self, newValue: Tuple) -> None: self._removed = False def remove(self) -> None: + """ + Remove this state. + """ self._defined = False self._updated = False self._removed = True def setTimeoutDuration(self, durationMs: int) -> None: + """ + Set the timeout duration in ms for this key. + Processing time timeout must be enabled. + """ if isinstance(durationMs, str): # TODO(SPARK-40437): Support string representation of durationMs. raise ValueError("durationMs should be int but get :%s" % type(durationMs)) @@ -133,6 +163,11 @@ def setTimeoutDuration(self, durationMs: int) -> None: # TODO(SPARK-40438): Implement additionalDuration parameter. def setTimeoutTimestamp(self, timestampMs: int) -> None: + """ + Set the timeout timestamp for this key as milliseconds in epoch time. + This timestamp cannot be older than the current watermark. 
+ Event time timeout must be enabled. + """ if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: raise RuntimeError( "Cannot set timeout duration without enabling processing time timeout in " @@ -157,6 +192,10 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: self._timeout_timestamp = timestampMs def getCurrentWatermarkMs(self) -> int: + """ + Get the current event time watermark as milliseconds in epoch time. + In a streaming query, this can be called only when watermark is set. + """ if not self._watermark_present: raise RuntimeError( "Cannot get event time watermark timestamp without setting watermark before " @@ -165,6 +204,11 @@ def getCurrentWatermarkMs(self) -> int: return self._event_time_watermark_ms def getCurrentProcessingTimeMs(self) -> int: + """ + Get the current processing time as milliseconds in epoch time. + In a streaming query, this will return a constant value throughout the duration of a + trigger, even if the trigger is re-executed. + """ return self._batch_processing_time_ms def __str__(self) -> str: @@ -174,6 +218,10 @@ def __str__(self) -> str: return "GroupState()" def json(self) -> str: + """ + Convert the internal values of instance into JSON. This is used to send out the update + from Python worker to executor. + """ return json.dumps( { # Constructor From 516fa4f2a41a1a9b879bc2f5a781774ba4738a57 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 19 Sep 2022 21:06:58 +0900 Subject: [PATCH 16/38] Revert "Update sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala" This reverts commit b93e488c6e4e522f3e529cf025e47d5617605cba. --- .../python/FlatMapGroupsInPandasWithStateExec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 6c5988037455..a106f3770a99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -65,10 +65,10 @@ case class FlatMapGroupsInPandasWithStateExec( child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. 
- override protected val initialStateDeserializer: Expression = _ - override protected val initialStateGroupAttrs: Seq[Attribute] = _ - override protected val initialStateDataAttrs: Seq[Attribute] = _ - override protected val initialState: SparkPlan = _ + override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null override protected val hasInitialState: Boolean = false override protected val stateEncoder: ExpressionEncoder[Any] = From 3ded76381f3d13be6bcfb1c58f8fdadb5e2218ea Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 09:17:03 +0900 Subject: [PATCH 17/38] fix style --- python/pyspark/sql/streaming/state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index bbadb253f352..66b225e1b10c 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -27,6 +27,7 @@ class GroupStateTimeout: """ Represents the type of timeouts possible for the Dataset operations applyInPandasWithState. """ + NoTimeout: str = "NoTimeout" ProcessingTimeTimeout: str = "ProcessingTimeTimeout" EventTimeTimeout: str = "EventTimeTimeout" From f00486f2e0c0e8ec680e9861b091e0293847e92e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 11:40:23 +0900 Subject: [PATCH 18/38] WIP still documenting... --- python/pyspark/sql/pandas/group_ops.py | 4 ++-- .../execution/python/ApplyInPandasWithStatePythonRunner.scala | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index b239ecacb9d7..b5628d0d3613 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -238,7 +238,7 @@ def applyInPandasWithState( The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as - :class:`pyspark.sql.streaming.state.GroupStateImpl`. + :class:`pyspark.sql.streaming.state.GroupState`. For each group, all columns are passed together as `pandas.DataFrame` to the user-function, and the returned `pandas.DataFrame` across all invocations are combined as a @@ -268,7 +268,7 @@ def applyInPandasWithState( a Python native function to be called on every group. It should take parameters (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`]. Note that the type of the key is tuple and the type of the state is - :class:`pyspark.sql.streaming.state.GroupStateImpl`. + :class:`pyspark.sql.streaming.state.GroupState`. outputStructType : :class:`pyspark.sql.types.DataType` or str the type of the output records. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. 
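
As a quick illustration of the signature documented above, here is a minimal sketch of a user function that consumes the (key, iterator of pandas.DataFrame, GroupState) arguments together with a processing-time timeout. This sketch is not part of the patch; the streaming DataFrame `df` with columns `id` and `value`, and the helper name `count_with_timeout`, are assumptions for illustration only.

    import pandas as pd
    from pyspark.sql.streaming.state import GroupState, GroupStateTimeout

    def count_with_timeout(key, pdf_iter, state: GroupState):
        # Illustrative only; `df`, `id`, `value` and this helper are assumed names.
        if state.hasTimedOut:
            # The key has been idle past the timeout: emit the final count and drop the state.
            (count,) = state.get
            state.remove()
            yield pd.DataFrame({"id": [key[0]], "count": [count]})
        else:
            count = state.get[0] if state.exists else 0
            for pdf in pdf_iter:
                count += len(pdf)
            state.update((count,))
            state.setTimeoutDuration(30000)  # re-arm a 30s processing-time timeout
            yield pd.DataFrame({"id": [key[0]], "count": [count]})

    result = df.groupBy("id").applyInPandasWithState(
        count_with_timeout,
        outputStructType="id long, count long",
        stateStructType="count long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )
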
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 54330629e9f3..89c77d23a9b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -43,6 +43,10 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** * A variant implementation of [[ArrowPythonRunner]] to serve the operation * applyInPandasWithState. + * + * Unlike normal ArrowPythonRunner which both input (executor to python worker) and output (python + * worker are InternalRow, applyInPandasWithState has side data (state information) in both input + * and output, which requires different struct on Arrow RecordBatch. */ class ApplyInPandasWithStatePythonRunner( funcs: Seq[ChainedPythonFunctions], From 2fb8da04b96e5f92982a643604a8988f5dc14ffa Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 12:31:15 +0900 Subject: [PATCH 19/38] Change the condition of constructing Arrow RecordBatch to number of records --- python/pyspark/sql/pandas/serializers.py | 60 ++---------------- python/pyspark/worker.py | 22 ++----- .../apache/spark/sql/internal/SQLConf.scala | 47 -------------- .../ApplyInPandasWithStatePythonRunner.scala | 23 +++---- .../python/ApplyInPandasWithStateWriter.scala | 63 +++++-------------- .../FlatMapGroupsInPandasWithStateExec.scala | 5 +- 6 files changed, 35 insertions(+), 185 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ac134338af75..74847f5241ec 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,8 +19,6 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. """ -import time - from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType @@ -391,13 +389,8 @@ class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): If True, then Pandas DataFrames will get columns by name state_object_schema : StructType The type of state object represented as Spark SQL type - soft_limit_bytes_per_batch : int - Soft limit of the accumulated size of records that can be written to a single - ArrowRecordBatch in memory. - min_data_count_for_sample : int - The minimum number of records to sample the size of record. - soft_timeout_millis_purge_batch : int - The soft timeout to force purging the ArrowRecordBatch regardless of the size. + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. 
""" def __init__( @@ -406,9 +399,7 @@ def __init__( safecheck, assign_cols_by_name, state_object_schema, - soft_limit_bytes_per_batch, - min_data_count_for_sample, - soft_timeout_millis_purge_batch, + arrow_max_records_per_batch, ): super(ApplyInPandasWithStateSerializer, self).__init__( timezone, safecheck, assign_cols_by_name @@ -427,9 +418,7 @@ def __init__( ) self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) - self.soft_limit_bytes_per_batch = soft_limit_bytes_per_batch - self.min_data_count_for_sample = min_data_count_for_sample - self.soft_timeout_millis_purge_batch = soft_timeout_millis_purge_batch + self.arrow_max_records_per_batch = arrow_max_records_per_batch def load_stream(self, stream): """ @@ -624,10 +613,6 @@ def init_stream_yield_batches(): pdf_data_cnt = 0 state_data_cnt = 0 - sampled_data_size_per_row = 0 - - last_purged_time_ns = time.time_ns() - for data in iterator: packaged_result = data[0] @@ -641,20 +626,7 @@ def init_stream_yield_batches(): pdf_data_cnt += len(pdf) pdfs.append(pdf) - if ( - sampled_data_size_per_row == 0 - and pdf_data_cnt > self.min_data_count_for_sample - ): - memory_usages = [p.memory_usage(deep=True).sum() for p in pdfs] - sampled_data_size_per_row = sum(memory_usages) / pdf_data_cnt - - # This effectively works after the sampling has completed, size we multiply - # by 0 if the sampling is still in progress. - batch_over_limit_on_size = ( - sampled_data_size_per_row * pdf_data_cnt - ) >= self.soft_limit_bytes_per_batch - - if batch_over_limit_on_size: + if pdf_data_cnt > self.arrow_max_records_per_batch: batch = construct_record_batch( pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt ) @@ -663,7 +635,6 @@ def init_stream_yield_batches(): state_pdfs = [] pdf_data_cnt = 0 state_data_cnt = 0 - last_purged_time_ns = time.time_ns() if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) @@ -697,27 +668,6 @@ def init_stream_yield_batches(): state_pdfs.append(state_pdf) state_data_cnt += 1 - cur_time_ns = time.time_ns() - is_timed_out_on_purge = ( - (cur_time_ns - last_purged_time_ns) // 1000000 - ) >= self.soft_timeout_millis_purge_batch - if is_timed_out_on_purge: - batch = construct_record_batch( - pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt - ) - - pdfs = [] - state_pdfs = [] - pdf_data_cnt = 0 - state_data_cnt = 0 - last_purged_time_ns = cur_time_ns - - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - - yield batch - # end of loop, we may have remaining data if pdf_data_cnt > 0 or state_data_cnt > 0: batch = construct_record_batch( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2c222359bec0..98409a3df9e4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -430,30 +430,18 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - soft_limit_bytes_per_batch = runner_conf.get( - "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch", - (64 * 1024 * 1024), + arrow_max_records_per_batch = runner_conf.get( + "spark.sql.execution.arrow.maxRecordsPerBatch", + 10000 ) - soft_limit_bytes_per_batch = int(soft_limit_bytes_per_batch) - - min_data_count_for_sample = runner_conf.get( - "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100 - ) - 
min_data_count_for_sample = int(min_data_count_for_sample) - - soft_timeout_millis_purge_batch = runner_conf.get( - "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", 100 - ) - soft_timeout_millis_purge_batch = int(soft_timeout_millis_purge_batch) + arrow_max_records_per_batch = int(arrow_max_records_per_batch) ser = ApplyInPandasWithStateSerializer( timezone, safecheck, assign_cols_by_name, state_object_schema, - soft_limit_bytes_per_batch, - min_data_count_for_sample, - soft_timeout_millis_purge_batch, + arrow_max_records_per_batch, ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c8acb8ac09cd..de25c19a26eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2705,44 +2705,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH = - buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch") - .internal() - .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " + - "records that can be written to a single ArrowRecordBatch in memory. This is used to " + - "restrict the amount of memory being used to materialize the data in both executor and " + - "Python worker. The accumulated size of records are calculated via sampling a set of " + - "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " + - "is quite huge, the size of constructed ArrowRecordBatch will be around the " + - "configured value.") - .version("3.4.0") - .bytesConf(ByteUnit.BYTE) - .createWithDefaultString("64MB") - - val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE = - buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample") - .internal() - .doc("When using applyInPandasWithState, specify the minimum number of records to sample " + - "the size of record. The size being retrieved from sampling will be used to estimate " + - "the accumulated size of records. Note that limiting by size does not work if the " + - "number of records are less than the configured value. For such case, ArrowRecordBatch " + - "will only be split for soft timeout.") - .version("3.4.0") - .intConf - .createWithDefault(100) - - val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH = - buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch") - .internal() - .doc("When using applyInPandasWithState, specify the soft timeout for purging the " + - "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " + - "the ArrowRecordBatch regardless of estimated size. 
This config ensures the receiver " + - "of data (both executor and Python worker) to not wait indefinitely for sender to " + - "complete the ArrowRecordBatch, which may hurt both throughput and latency.") - .version("3.4.0") - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefaultString("100ms") - val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -4567,15 +4529,6 @@ class SQLConf extends Serializable with Logging { def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) - def softLimitBytesPerBatchInApplyInPandasWithState: Long = - getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH) - - def minDataCountForSampleInApplyInPandasWithState: Int = - getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE) - - def softTimeoutMillisPurgeBatchInApplyInPandasWithState: Long = - getConf(SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH) - def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 89c77d23a9b6..7d1e005222c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -58,34 +58,30 @@ class ApplyInPandasWithStatePythonRunner( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, valueSchema: StructType, - stateValueSchema: StructType, - softLimitBytesPerBatch: Long, - minDataCountForSample: Int, - softTimeoutMillsPurgeBatch: Long) + stateValueSchema: StructType) extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { + private val sqlConf = SQLConf.get + override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback - override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + override val bufferSize: Int = sqlConf.pandasUDFBufferSize require( bufferSize >= 4, "Pandas execution requires more than 4 bytes. Please set higher buffer. " + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. // Configurations are both applied to executor and Python worker, set them to the worker conf // to let Python worker read the config properly. 
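  // The Python worker reads the same entry back in read_udfs and passes it to
  // ApplyInPandasWithStateSerializer as arrow_max_records_per_batch, so the executor and the
  // Python worker construct Arrow RecordBatches under the same record-count threshold.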
override protected val workerConf: Map[String, String] = initialWorkerConf + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key -> - softLimitBytesPerBatch.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key -> - minDataCountForSample.toString) + - (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key -> - softTimeoutMillsPurgeBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) private val stateRowDeserializer = stateEncoder.createDeserializer() @@ -100,8 +96,7 @@ class ApplyInPandasWithStatePythonRunner( writer: ArrowStreamWriter, dataOut: DataOutputStream, inputIterator: Iterator[InType]): Unit = { - val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch, - minDataCountForSample, softTimeoutMillsPurgeBatch) + val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) while (inputIterator.hasNext) { val (keyRow, groupState, dataIter) = inputIterator.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala index 4161aa88083e..78e208fec38e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -49,9 +49,7 @@ import org.apache.spark.unsafe.types.UTF8String class ApplyInPandasWithStateWriter( root: VectorSchemaRoot, writer: ArrowStreamWriter, - softLimitBytesPerBatch: Long, - minDataCountForSample: Int, - softTimeoutMillsPurgeBatch: Long) { + arrowMaxRecordsPerBatch: Int) { import ApplyInPandasWithStateWriter._ @@ -70,7 +68,7 @@ class ApplyInPandasWithStateWriter( private val arrowWriterForState = createArrowWriter( root.getFieldVectors.asScala.toSeq.takeRight(1)) - // Bin-packing + // - Bin-packing // // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to // gain the performance. In many cases, the amount of data per grouping key is quite @@ -81,42 +79,28 @@ class ApplyInPandasWithStateWriter( // the range of data and give a view, say, "zero-copy". To help splitting the range for // data, we provide the "start offset" and the "number of data" in the state metadata. // - // We don't bin-pack all groups into a single record batch - we have a soft limit on the size - // of Arrow RecordBatch to stop adding next group. + // We don't bin-pack all groups into a single record batch - we have a limit on the number + // of rows in the current Arrow RecordBatch to stop adding next group. // - // Chunking + // - Chunking // // We also chunk the data from single group into multiple Arrow RecordBatch to ensure // scalability. Note that we don't know the volume (number of rows, overall size) of data for // specific group key before we read the entire data. The easiest approach to address both - // bin-pack and chunk is to check the size of the current Arrow RecordBatch per each write of - // row. + // bin-pack and chunk is to check the number of rows in the current Arrow RecordBatch for each + // write of row. // - // How to measure the size of data? + // - Consideration // - // We perform some basic sampling for data to guess the size of the data very roughly, - // and simply multiply by the number of data to estimate the size. 
We extract the size of - // data from the record batch rather than UnsafeRow, as we don't hold the memory for - // UnsafeRow once we write to the record batch. If there is a memory bound here, it - // should come from record batch. - // - // Why we also perform timeout on batching Arrow RecordBatch? - // - // In the meanwhile, we don't also want to let the current record batch collect the data - // indefinitely, since we are pipelining the process between executor and python worker. - // Python worker won't process any data if executor is not yet finalized a record - // batch, which defeats the purpose of pipelining. To address this, we also introduce - // timeout for constructing a record batch. This is a soft limit - we allow current group - // to write all data even it's timed-out. + // Since the number of rows in Arrow RecordBatch does not represent the actual size (bytes), + // the limit should be set very conservatively. Using a small number of limit does not introduce + // correctness issues. private var numRowsForCurGroup = 0 private var startOffsetForCurGroup = 0 private var totalNumRowsForBatch = 0 private var totalNumStatesForBatch = 0 - private var sampledDataSizePerRow = 0 - private var lastBatchPurgedMillis = System.currentTimeMillis() - private var currentGroupKeyRow: UnsafeRow = _ private var currentGroupState: GroupStateImpl[Row] = _ @@ -126,21 +110,10 @@ class ApplyInPandasWithStateWriter( } def writeRow(dataRow: InternalRow): Unit = { - // Currently, this only works when the number of rows are greater than the minimum - // data count for sampling. And we technically have no way to pick some rows from - // record batch and measure the size of data, hence we leverage all data in current - // record batch. We only sample once as it could be costly. - if (sampledDataSizePerRow == 0 && totalNumRowsForBatch > minDataCountForSample) { - sampledDataSizePerRow = arrowWriterForData.sizeInBytes() / totalNumRowsForBatch - } - - // If it exceeds the condition of batch (only size, not about timeout) and - // there is more data for the same group, flush and construct a new batch. + // If it exceeds the condition of batch (number of records) and there is more data for the + // same group, finalize and construct a new batch. - // The soft-limit on size effectively works after the sampling has completed, since we - // multiply the number of rows by 0 if the sampling is still in progress. - - if (sampledDataSizePerRow * totalNumRowsForBatch >= softLimitBytesPerBatch) { + if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) { // Provide state metadata row as intermediate val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false) @@ -165,16 +138,11 @@ class ApplyInPandasWithStateWriter( // The start offset for next group would be same as the total number of rows for batch, // unless the next group starts with new batch. startOffsetForCurGroup = totalNumRowsForBatch - - // The soft-limit on timeout applies on finalization of each group. - if (System.currentTimeMillis() - lastBatchPurgedMillis > softTimeoutMillsPurgeBatch) { - finalizeCurrentArrowBatch() - } } def finalizeData(): Unit = { if (numRowsForCurGroup > 0) { - // We still have some rows in the current record batch. Need to flush them as well. + // We still have some rows in the current record batch. Need to finalize them as well. 
finalizeCurrentArrowBatch() } } @@ -224,7 +192,6 @@ class ApplyInPandasWithStateWriter( numRowsForCurGroup = 0 totalNumRowsForBatch = 0 totalNumStatesForBatch = 0 - lastBatchPurgedMillis = System.currentTimeMillis() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index a106f3770a99..4869902a19eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -156,10 +156,7 @@ case class FlatMapGroupsInPandasWithStateExec( stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, child.output.toStructType, - stateType, - conf.softLimitBytesPerBatchInApplyInPandasWithState, - conf.minDataCountForSampleInApplyInPandasWithState, - conf.softTimeoutMillisPurgeBatchInApplyInPandasWithState) + stateType) val context = TaskContext.get() From 1a4f158f02d52f9880e867c711e17a4529f1426f Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 13:05:50 +0900 Subject: [PATCH 20/38] small fix --- .../ApplyInPandasWithStatePythonRunner.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 7d1e005222c3..1745269b9515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -69,11 +69,17 @@ class ApplyInPandasWithStatePythonRunner( override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback - override val bufferSize: Int = sqlConf.pandasUDFBufferSize - require( - bufferSize >= 4, - "Pandas execution requires more than 4 bytes. Please set higher buffer. " + - s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + override val bufferSize: Int = { + val configuredSize = sqlConf.pandasUDFBufferSize + if (configuredSize < 4) { + logWarning("Pandas execution requires more than 4 bytes. Please configure bigger value " + + s"for the configuration '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'. 
" + + "Force using the value '4'.") + 4 + } else { + configuredSize + } + } private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch From c5b35a4162262df0f8c72c705e5b5fb8f5093328 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 15:42:35 +0900 Subject: [PATCH 21/38] more doc --- .../python/ApplyInPandasWithStateWriter.scala | 45 +++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala index 78e208fec38e..d9f256460984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -38,7 +38,8 @@ import org.apache.spark.unsafe.types.UTF8String * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data * and state into Arrow RecordBatches with performing bin-pack and chunk internally. * - * This class requires that the parameter `root` has initialized with the Arrow schema like below: + * This class requires that the parameter `root` has been initialized with the Arrow schema like + * below: * - data fields * - state field * - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA) @@ -53,16 +54,24 @@ class ApplyInPandasWithStateWriter( import ApplyInPandasWithStateWriter._ - // We logically group the columns by family (data vs state) and initialize writer separately, - // since it's lot more easier and probably performant to write the row directly rather than - // projecting the row to match up with the overall schema. + // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState requires to produce + // the additional data `state`, along with the input data. // - // The number of data rows and state metadata rows can be different which could be problematic - // for Arrow RecordBatch, so we append empty rows to ensure both have the same number of rows. + // ArrowStreamWriter supports only single VectorSchemaRoot, which means all Arrow RecordBatches + // being sent out from ArrowStreamWriter should have same schema. That said, we have to construct + // "an" Arrow schema to contain both types of data, and also construct Arrow RecordBatches to + // contain both data. // - // We always produce at least one data row per grouping key whereas we only produce one - // state metadata row per grouping key, so we only need to fill up the empty rows in - // state metadata side. + // To achieve this, we extend the schema for input data to have a column for state at the end. + // But also, we logically group the columns by family (data vs state) and initialize writer + // separately, since it's lot more easier and probably performant to write the row directly + // rather than projecting the row to match up with the overall schema. + // + // Although Arrow RecordBatch enables to write the data as columnar, we figure out it gives + // strange outputs if we don't ensure that all columns have the same number of values. Since + // there are one or more data for a grouping key (applies to case of handling timed out state + // as well) whereas there is only one state for a grouping key, we have to fill up the empty rows + // in state side to ensure both have the same number of rows. 
private val arrowWriterForData = createArrowWriter( root.getFieldVectors.asScala.toSeq.dropRight(1)) private val arrowWriterForState = createArrowWriter( @@ -104,11 +113,22 @@ class ApplyInPandasWithStateWriter( private var currentGroupKeyRow: UnsafeRow = _ private var currentGroupState: GroupStateImpl[Row] = _ + /** + * Indicates writer to start with new grouping key. + * + * @param keyRow The grouping key row for current group. + * @param groupState The instance of GroupStateImpl for current group. + */ def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = { currentGroupKeyRow = keyRow currentGroupState = groupState } + /** + * Indicates writer to write a row in the current group. + * + * @param dataRow The row to write in the current group. + */ def writeRow(dataRow: InternalRow): Unit = { // If it exceeds the condition of batch (number of records) and there is more data for the // same group, finalize and construct a new batch. @@ -128,6 +148,10 @@ class ApplyInPandasWithStateWriter( totalNumRowsForBatch += 1 } + /** + * Indicates writer that current group has finalized and there will be no further row bound to + * the current group. + */ def finalizeGroup(): Unit = { // Provide state metadata row val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, @@ -140,6 +164,9 @@ class ApplyInPandasWithStateWriter( startOffsetForCurGroup = totalNumRowsForBatch } + /** + * Indicates writer that all groups have been processed. + */ def finalizeData(): Unit = { if (numRowsForCurGroup > 0) { // We still have some rows in the current record batch. Need to finalize them as well. From a95df289ee5673c95ded09f2122160340ae8d5a5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 17:34:53 +0900 Subject: [PATCH 22/38] still documenting... --- python/pyspark/sql/pandas/serializers.py | 136 +++++++++++------- python/pyspark/worker.py | 9 ++ .../ApplyInPandasWithStatePythonRunner.scala | 9 +- .../python/ApplyInPandasWithStateWriter.scala | 6 +- 4 files changed, 101 insertions(+), 59 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 74847f5241ec..9bff0a1013cd 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -557,21 +557,51 @@ def gen_data_and_state(batches): def dump_stream(self, iterator, stream): """ - # TODO: document - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. + Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow + RecordBatches, and write batches to stream. """ - def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + import pandas as pd + import pyarrow as pa + + def construct_state_pdf(state): """ - # TODO: document - Arrow RecordBatch requires all columns to have all same number of rows. - Insert empty data for state/data with less elements to compensate. + Construct a pandas DataFrame from the state instance. 
""" - import pandas as pd - import pyarrow as pa + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + state_old_timeout_timestamp = state.oldTimeoutTimestamp + + state_dict = { + "properties": [ + state_properties, + ], + "keyRowAsUnsafe": [ + state_key_row_as_binary, + ], + "object": [ + state_object, + ], + "oldTimeoutTimestamp": [ + state_old_timeout_timestamp, + ], + } + + return pd.DataFrame.from_dict(state_dict) + + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each + one matches to the single struct field for Arrow schema, hence the return value of + Arrow RecordBatch will have schema with two fields, in `data`, `state` order. + (Readers are expected to access the field via position rather than the name. We do + not guarantee the name of the field.) + + Note that Arrow RecordBatch requires all columns to have all same number of rows, + hence this function inserts empty data for state/data with less elements to compensate. + """ max_data_cnt = max(pdf_data_cnt, state_data_cnt) @@ -597,86 +627,88 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)] ) - def init_stream_yield_batches(): + def serialize_batches(): """ - # TODO: document - :return: - """ - import pandas as pd - - should_write_start_length = True + Read through an iterator of (iterator of pandas DataFrame, state), and serialize them + to Arrow RecordBatches. + This function does batching on constructing the Arrow RecordBatch; a batch will be + serialized to the Arrow RecordBatch when the total number of records exceeds the + configured threshold. + """ + # a set of variables for the state of current batch which will be converted to Arrow + # RecordBatch. pdfs = [] state_pdfs = [] - return_schema = None - pdf_data_cnt = 0 state_data_cnt = 0 + return_schema = None + for data in iterator: + # data represents the result of each call of user function packaged_result = data[0] + # There are two results from the call of user function: + # 1) iterator of pandas DataFrame (output) + # 2) updated state instance pdf_iter = packaged_result[0][0] state = packaged_result[0][1] - # this won't change across batches + + # This is static and won't change across batches. return_schema = packaged_result[1] + state_pdf = construct_state_pdf(state) + + state_pdfs.append(state_pdf) + state_data_cnt += 1 + for pdf in pdf_iter: + # We ignore empty pandas DataFrame. if len(pdf) > 0: pdf_data_cnt += len(pdf) pdfs.append(pdf) + # If the total number of records in current batch exceeds the configured + # threshold, time to construct the Arrow RecordBatch from the batch. if pdf_data_cnt > self.arrow_max_records_per_batch: batch = construct_record_batch( pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt ) + # Reset the variables to start with new batch for further data. 
pdfs = [] state_pdfs = [] pdf_data_cnt = 0 state_data_cnt = 0 - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - yield batch - # pick up state for only last chunk as state should have been updated so far - state_properties = state.json().encode("utf-8") - state_key_row_as_binary = state._keyAsUnsafe - state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) - state_old_timeout_timestamp = state.oldTimeoutTimestamp - - state_dict = { - "properties": [ - state_properties, - ], - "keyRowAsUnsafe": [ - state_key_row_as_binary, - ], - "object": [ - state_object, - ], - "oldTimeoutTimestamp": [ - state_old_timeout_timestamp, - ], - } - - state_pdf = pd.DataFrame.from_dict(state_dict) - - state_pdfs.append(state_pdf) - state_data_cnt += 1 - - # end of loop, we may have remaining data + # processed all output, but current batch may not be flushed yet. if pdf_data_cnt > 0 or state_data_cnt > 0: batch = construct_record_batch( pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt ) + yield batch + + def init_stream_yield_batches(batches): + """ + This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a + START_ARROW_STREAM before the Arrow stream is sent. + + START_ARROW_STREAM should be sent after creating the first record batch so in case of + an error, it can be sent back to the JVM before the Arrow stream starts. + """ + should_write_start_length = True + + for batch in batches: if should_write_start_length: write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False yield batch - return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + batches_to_write = init_stream_yield_batches(serialize_batches()) + + return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 98409a3df9e4..daf78118361b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -210,7 +210,13 @@ def wrapped(key_series, value_series): def wrap_grouped_map_pandas_udf_with_state(f, return_type): + """ + # FIXME: document + """ def wrapped(key_series, value_series_gen, state): + """ + # FIXME: document + """ import pandas as pd key = tuple(s[0] for s in key_series) @@ -579,6 +585,9 @@ def mapper(a): parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): + """ + # FIXME: document + """ from itertools import tee state = a[1] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 1745269b9515..ed381a84a56e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} * A variant implementation of [[ArrowPythonRunner]] to serve the operation * applyInPandasWithState. * - * Unlike normal ArrowPythonRunner which both input (executor to python worker) and output (python - * worker are InternalRow, applyInPandasWithState has side data (state information) in both input - * and output, which requires different struct on Arrow RecordBatch. 
+ * Unlike the normal ArrowPythonRunner, whose input and output (executor <-> Python worker)
+ * are InternalRow, applyInPandasWithState carries side data (state information) in both input
+ * and output along with the data, which requires a different structure for the Arrow RecordBatch.
  */
 class ApplyInPandasWithStatePythonRunner(
     funcs: Seq[ChainedPythonFunctions],
@@ -166,7 +166,8 @@ class ApplyInPandasWithStatePythonRunner(
       // The entire row in record batch seems to be for data.
       None
     } else {
-      // NOTE: See StateReaderIterator.STATE_METADATA_SCHEMA for the schema.
+      // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
+      // for the schema.
       val propertiesAsJson = parse(row.getUTF8String(0).toString)
       val keyRowAsUnsafeAsBinary = row.getBinary(1)
       val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
index d9f256460984..eed3c16be829 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -69,9 +69,9 @@ class ApplyInPandasWithStateWriter(
   //
   // Although Arrow RecordBatch enables to write the data as columnar, we figure out it gives
   // strange outputs if we don't ensure that all columns have the same number of values. Since
-  // there are one or more data for a grouping key (applies to case of handling timed out state
-  // as well) whereas there is only one state for a grouping key, we have to fill up the empty rows
-  // in state side to ensure both have the same number of rows.
+  // there is at least one data row for a grouping key (we ensure this for the case of handling
+  // timed-out state as well) whereas there is only one state row for a grouping key, we have to
+  // fill up the empty rows on the state side to ensure both have the same number of rows.
   private val arrowWriterForData = createArrowWriter(
     root.getFieldVectors.asScala.toSeq.dropRight(1))
   private val arrowWriterForState = createArrowWriter(

From 0fee506ba45c07c23e9ec19adae9066210da5952 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim 
Date: Tue, 20 Sep 2022 17:44:47 +0900
Subject: [PATCH 23/38] slight refactor

---
 python/pyspark/sql/pandas/serializers.py | 49 ++++++++++++++----------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 8aa27cb1b947..85ee6d02d9ea 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -438,6 +438,28 @@ def load_stream(self, stream):
         from itertools import groupby
         from pyspark.sql.streaming.state import GroupState

+        def construct_state(state_info_col):
+            """
+            Construct a state instance from the value of the state information column.
+ """ + + state_info_col_properties = state_info_col["properties"] + state_info_col_key_row = state_info_col["keyRowAsUnsafe"] + state_info_col_object = state_info_col["object"] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + return GroupState( + keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties, + ) + def gen_data_and_state(batches): """ Deserialize ArrowRecordBatches and return a generator of @@ -498,32 +520,17 @@ def gen_data_and_state(batches): # no more data with grouping key + state break - state_info_col_properties = state_info_col["properties"] - state_info_col_key_row = state_info_col["keyRowAsUnsafe"] - state_info_col_object = state_info_col["object"] - data_start_offset = state_info_col["startOffset"] num_data_rows = state_info_col["numRows"] is_last_chunk = state_info_col["isLastChunk"] - state_properties = json.loads(state_info_col_properties) - if state_info_col_object: - state_object = self.pickleSer.loads(state_info_col_object) - else: - state_object = None - state_properties["optionalValue"] = state_object - if state_for_current_group: # use the state, we already have state for same group and there should be # some data in same group being processed earlier state = state_for_current_group else: # there is no state being stored for same group, construct one - state = GroupState( - keyAsUnsafe=state_info_col_key_row, - valueSchema=self.state_object_schema, - **state_properties, - ) + state = construct_state(state_info_col) if is_last_chunk: # discard the state being cached for same group @@ -544,15 +551,15 @@ def gen_data_and_state(batches): state, ) - batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) - data_state_generator = gen_data_and_state(batches) + data_state_generator = gen_data_and_state(_batches) # state will be same object for same grouping key - for state, data in groupby(data_state_generator, key=lambda x: x[1]): + for _state, _data in groupby(data_state_generator, key=lambda x: x[1]): yield ( - data, - state, + _data, + _state, ) def dump_stream(self, iterator, stream): From 7051799669c1cd8ca40b3a77379367466a2ecee4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 20:12:00 +0900 Subject: [PATCH 24/38] further documentation --- python/pyspark/worker.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index daf78118361b..45413d62000d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -211,11 +211,29 @@ def wrapped(key_series, value_series): def wrap_grouped_map_pandas_udf_with_state(f, return_type): """ - # FIXME: document + Provides a new lambda instance wrapping user function of applyInPandasWithState. + + The lambda instance receives (key series, iterator of value series, state) and performs + some conversion to be adapted with the signature of user function. + + See the function doc of inner function `wrapped` for more details on what adapter does. + See the function doc of `mapper` function for + `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on + the input parameters of lambda function. 
+ + Along with the returned iterator, the lambda instance will also produce the return_type as + converted to the arrow schema. """ + def wrapped(key_series, value_series_gen, state): """ - # FIXME: document + Provide an adapter of the user function performing below: + + - Extract the first value of all columns in key series and produce as a tuple. + - If the state has timed out, call the user function with empty pandas DataFrame. + - If not, construct a new generator which converts each element of value series to + pandas DataFrame (lazy evaluation), and call the user function with the generator + - Verify each element of returned iterator to check the schema of pandas DataFrame. """ import pandas as pd @@ -586,7 +604,12 @@ def mapper(a): def mapper(a): """ - # FIXME: document + The function receives (iterator of data, state) and performs extraction of key and + value from the data, with retaining lazy evaluation. + + See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input + and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will + be used. """ from itertools import tee From 514294162e11b88c947f666fe885dd1afb9451ce Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 21:07:18 +0900 Subject: [PATCH 25/38] further document... --- .../ApplyInPandasWithStatePythonRunner.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index ed381a84a56e..0a451508cd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -91,12 +91,23 @@ class ApplyInPandasWithStatePythonRunner( private val stateRowDeserializer = stateEncoder.createDeserializer() + /** + * This method sends out the additional metadata before sending out actual data. + * + * Specifically, this class overrides this method to also write the schema for state value. + */ override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { super.handleMetadataBeforeExec(stream) // Also write the schema for state value PythonRDD.writeUTF(stateValueSchema.json, stream) } + /** + * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and + * write constructed RecordBatches to the writer. + * + * See [[ApplyInPandasWithStateWriter]] for more details. + */ protected def writeIteratorToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -120,6 +131,10 @@ class ApplyInPandasWithStatePythonRunner( w.finalizeData() } + /** + * Deserialize ColumnarBatch received from the Python worker to produce the output. Schema info + * for given ColumnarBatch is also provided as well. + */ protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { // This should at least have one row for state. 
Also, we ensure that all columns across // data and state metadata have same number of rows, which is required by Arrow record From 95a1400284572e841822353c9113b788fa359600 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 21:10:59 +0900 Subject: [PATCH 26/38] further doc --- .../spark/sql/RelationalGroupedDataset.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 69eb8101abf7..0429fd27a41e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -622,6 +622,20 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + /** + * Applies a grouped vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: iterator of `pandas.DataFrame` -> + * iterator of `pandas.DataFrame`. + * For each group, all elements in the group are passed as an iterator of `pandas.DataFrame` + * along with corresponding state, and the results for all groups are combined into a new + * [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. + */ private[sql] def applyInPandasWithState( func: PythonUDF, outputStructType: StructType, From 295bc9b25f12d0f1e944f75a44b2a92de1d34f57 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 20 Sep 2022 21:13:50 +0900 Subject: [PATCH 27/38] apply suggestion --- python/pyspark/sql/pandas/group_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index b5628d0d3613..776b174bdf2b 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -254,8 +254,9 @@ def applyInPandasWithState( The `stateStructType` should be :class:`StructType` describing the schema of the user-defined state. The value of the state will be presented as a tuple, as well as the - update should be performed with the tuple. User defined types e.g. native Python class - types are not supported. + update should be performed with the tuple. The corresponding Python types for + :class:DataType are supported. Please refer to the page + https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab). The size of each DataFrame in both the input and output can be arbitrary. The number of DataFrames in both the input and output can also be arbitrary. 
From 9f52c80a8773bd082ac74635b66af3a7d0a314f7 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 21 Sep 2022 05:45:28 +0900 Subject: [PATCH 28/38] update the doc --- .../python/ApplyInPandasWithStateWriter.scala | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala index eed3c16be829..8295847e3c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -59,8 +59,8 @@ class ApplyInPandasWithStateWriter( // // ArrowStreamWriter supports only single VectorSchemaRoot, which means all Arrow RecordBatches // being sent out from ArrowStreamWriter should have same schema. That said, we have to construct - // "an" Arrow schema to contain both types of data, and also construct Arrow RecordBatches to - // contain both data. + // "an" Arrow schema to contain both data and state, and also construct ArrowBatches to contain + // both data and state. // // To achieve this, we extend the schema for input data to have a column for state at the end. // But also, we logically group the columns by family (data vs state) and initialize writer @@ -99,6 +99,13 @@ class ApplyInPandasWithStateWriter( // bin-pack and chunk is to check the number of rows in the current Arrow RecordBatch for each // write of row. // + // - Data and State + // + // Since we apply bin-packing and chunking, there should be the way to distinguish each chunk + // from the entire data part of Arrow RecordBatch. We leverage the state metadata to also + // contain the "metadata" of data part to distinguish the chunk from the entire data. + // As a result, state metadata has a 1-1 relationship with "chunk", instead of "grouping key". + // // - Consideration // // Since the number of rows in Arrow RecordBatch does not represent the actual size (bytes), @@ -223,13 +230,31 @@ class ApplyInPandasWithStateWriter( } object ApplyInPandasWithStateWriter { + // This schema contains both state metadata and the metadata of the chunk. Refer the code comment + // of "Data and State" for more details. 
From 8d23d46959c63e9555d45d54cdc0be5ddcf4a7d6 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 09:13:22 +0900
Subject: [PATCH 29/38] Add example

---
 python/pyspark/sql/pandas/group_ops.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 776b174bdf2b..7d49b9fe8fe6 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -282,7 +282,20 @@ def applyInPandasWithState(
             timeout configuration for groups that do not receive data for a while. valid values
             are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
 
-        # TODO: Examples
+        Examples
+        --------
+        >>> import ...
+        >>> def count_fn(key, pdf_iter, state):
+        ...     assert isinstance(state, GroupStateImpl)
+        ...     total_len = 0
+        ...     for pdf in pdf_iter:
+        ...         total_len += len(pdf)
+        ...     state.update((total_len,))
+        ...     yield pandas.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
+        >>> df.groupby("id").applyInPandasWithState(
+        ...     count_fn, outputStructType="id long, countAsString string",
+        ...     stateStructType="len long", outputMode="Update",
+        ...     timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP

From fc3dca29ad99e2dffc54d916a3bfbc634f412ba2 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 09:15:22 +0900
Subject: [PATCH 30/38] slight addition

---
 python/pyspark/sql/pandas/group_ops.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 7d49b9fe8fe6..1d38f0566969 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -296,6 +296,12 @@ def applyInPandasWithState(
         ...     count_fn, outputStructType="id long, countAsString string",
         ...     stateStructType="len long", outputMode="Update",
         ...     timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP
+
+        Notes
+        -----
+        This function requires a full shuffle.
+
+        This API is experimental.
         """
         from pyspark.sql import GroupedData

From 119daf382cd6e5257b9ab2bc64bed547869d1bfe Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 10:29:25 +0900
Subject: [PATCH 31/38] fix style

---
 python/pyspark/worker.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 45413d62000d..5861330413d2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -455,8 +455,7 @@ def read_udfs(pickleSer, infile, eval_type):
         ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         arrow_max_records_per_batch = runner_conf.get(
-            "spark.sql.execution.arrow.maxRecordsPerBatch",
-            10000
+            "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
         )
         arrow_max_records_per_batch = int(arrow_max_records_per_batch)
From 4b7c6672be1086ed8870c44fe053ab411ec94c71 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 13:21:52 +0900
Subject: [PATCH 32/38] fix bug during updates of code review

---
 python/pyspark/sql/pandas/serializers.py             | 12 +++++++-----
 .../python/ApplyInPandasWithStatePythonRunner.scala  |  7 ++++---
 .../python/FlatMapGroupsInPandasWithStateExec.scala  |  2 +-
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 8aa27cb1b947..85ee6d02d9ea 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -665,11 +665,6 @@ def serialize_batches():
                 # This is static and won't change across batches.
                 return_schema = packaged_result[1]
 
-                state_pdf = construct_state_pdf(state)
-
-                state_pdfs.append(state_pdf)
-                state_data_cnt += 1
-
                 for pdf in pdf_iter:
                     # We ignore empty pandas DataFrame.
                     if len(pdf) > 0:
@@ -691,6 +686,13 @@
 
                     yield batch
 
+                # This has to be performed 'after' evaluating all elements in the iterator, so
+                # that the user function has completed and the state is guaranteed to be updated.
+                state_pdf = construct_state_pdf(state)
+
+                state_pdfs.append(state_pdf)
+                state_data_cnt += 1
+
             # processed all output, but current batch may not be flushed yet.
             if pdf_data_cnt > 0 or state_data_cnt > 0:
                 batch = construct_record_batch(

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 0a451508cd80..bd8c72029dcb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -57,7 +57,7 @@ class ApplyInPandasWithStatePythonRunner(
     initialWorkerConf: Map[String, String],
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
-    valueSchema: StructType,
+    outputSchema: StructType,
     stateValueSchema: StructType)
   extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
   with PythonArrowInput[InType]
@@ -149,7 +149,8 @@
     // UDF returns a StructType column in ColumnarBatch, select the children here
     val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
     val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
-    assert(dataType.sameType(expectedType))
+    assert(dataType.sameType(expectedType),
+      s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType")
     val outputVectors = dataType.indices.map(structVector.getChild)
     val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
@@ -159,7 +160,7 @@
   }
 
   def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
-    val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, valueSchema)
+    val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema)
     dataBatch.rowIterator.asScala.flatMap { row =>
       if (row.isNullAt(0)) {
         // The entire row in record batch seems to be for state metadata.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 4869902a19eb..159f805f7347 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -155,7 +155,7 @@ case class FlatMapGroupsInPandasWithStateExec(
       pythonRunnerConf,
       stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
       groupingAttributes.toStructType,
-      child.output.toStructType,
+      outAttributes.toStructType,
       stateType)
 
     val context = TaskContext.get()
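The serializer fix in the patch above works because the user function runs lazily: its state update only happens while the output iterator is being drained. The following toy sketch (plain Python, no Spark; `FakeState` and `user_func` are made-up stand-ins, not PySpark APIs) shows why snapshotting the state before exhausting the iterator reads a stale value.

    # Toy model of the generator timing only, not the actual serializer code.
    class FakeState:
        def __init__(self):
            self.value = None

        def update(self, value):
            self.value = value

    def user_func(pdfs, state):
        total = 0
        for pdf_len in pdfs:        # pretend each element is a pandas DataFrame
            total += pdf_len
            yield pdf_len           # output is produced lazily
        state.update((total,))      # state is only updated once the iterator is exhausted

    state = FakeState()
    out = user_func([3, 4, 5], state)

    snapshot_too_early = state.value   # None: the generator has not run yet
    consumed = list(out)               # drain the output, as serialize_batches() does
    snapshot_after = state.value       # (12,): now safe to serialize the state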
From 426f5e71a061592b4fb6616f8afa70d0c0a2cac1 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 14:38:11 +0900
Subject: [PATCH 33/38] fix another bug during code review

---
 .../python/ApplyInPandasWithStateWriter.scala | 51 ++++++++++---------
 1 file changed, 28 insertions(+), 23 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
index 8295847e3c81..70ded3dafd7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -112,13 +112,18 @@ class ApplyInPandasWithStateWriter(
   // the limit should be set very conservatively. Using a small number of limit does not introduce
   // correctness issues.
 
-  private var numRowsForCurGroup = 0
-  private var startOffsetForCurGroup = 0
+  // variables for tracking current grouping key and state
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  // variables for tracking the status of current batch
   private var totalNumRowsForBatch = 0
   private var totalNumStatesForBatch = 0
 
-  private var currentGroupKeyRow: UnsafeRow = _
-  private var currentGroupState: GroupStateImpl[Row] = _
+  // variables for tracking the status of current chunk
+  private var startOffsetForCurrentChunk = 0
+  private var numRowsForCurrentChunk = 0
+
   /**
    * Indicates writer to start with new grouping key.
@@ -141,17 +146,13 @@ class ApplyInPandasWithStateWriter(
     // same group, finalize and construct a new batch.
 
     if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
-      // Provide state metadata row as intermediate
-      val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
-        startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false)
-      arrowWriterForState.write(stateInfoRow)
-      totalNumStatesForBatch += 1
-
+      finalizeCurrentChunk(isLastChunkForGroup = false)
       finalizeCurrentArrowBatch()
     }
 
     arrowWriterForData.write(dataRow)
-    numRowsForCurGroup += 1
+
+    numRowsForCurrentChunk += 1
     totalNumRowsForBatch += 1
   }
@@ -160,22 +161,14 @@
    * the current group.
    */
  def finalizeGroup(): Unit = {
-    // Provide state metadata row
-    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
-      startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true)
-    arrowWriterForState.write(stateInfoRow)
-    totalNumStatesForBatch += 1
-
-    // The start offset for next group would be same as the total number of rows for batch,
-    // unless the next group starts with new batch.
-    startOffsetForCurGroup = totalNumRowsForBatch
+    finalizeCurrentChunk(isLastChunkForGroup = true)
   }
 
   /**
    * Indicates writer that all groups have been processed.
   */
  def finalizeData(): Unit = {
-    if (numRowsForCurGroup > 0) {
+    if (totalNumRowsForBatch > 0) {
       // We still have some rows in the current record batch. Need to finalize them as well.
       finalizeCurrentArrowBatch()
     }
@@ -210,6 +203,18 @@
     new GenericInternalRow(Array[Any](stateUnderlyingRow))
   }
 
+  private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+      startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup)
+    arrowWriterForState.write(stateInfoRow)
+    totalNumStatesForBatch += 1
+
+    // The start offset for the next chunk would be the same as the total number of rows for
+    // the batch, unless the next chunk starts with a new batch.
+    startOffsetForCurrentChunk = totalNumRowsForBatch
+    numRowsForCurrentChunk = 0
+  }
+
   private def finalizeCurrentArrowBatch(): Unit = {
     val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
     (0 until remainingEmptyStateRows).foreach { _ =>
@@ -222,8 +227,8 @@
     arrowWriterForState.reset()
     arrowWriterForData.reset()
 
-    startOffsetForCurGroup = 0
-    numRowsForCurGroup = 0
+    startOffsetForCurrentChunk = 0
+    numRowsForCurrentChunk = 0
     totalNumRowsForBatch = 0
     totalNumStatesForBatch = 0
   }
From 83f2555fe838be1ccb0a45d6c20cba56f45debfb Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 16:08:54 +0900
Subject: [PATCH 34/38] fix on pydoc

---
 python/pyspark/sql/pandas/group_ops.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 1d38f0566969..4ac5cfcbdc74 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -284,14 +284,15 @@ def applyInPandasWithState(
 
         Examples
         --------
-        >>> import ...
+        >>> import pandas as pd  # doctest: +SKIP
+        >>> from pyspark.sql.streaming.state import GroupStateTimeout
        >>> def count_fn(key, pdf_iter, state):
        ...     assert isinstance(state, GroupStateImpl)
        ...     total_len = 0
        ...     for pdf in pdf_iter:
        ...         total_len += len(pdf)
        ...     state.update((total_len,))
-        ...     yield pandas.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
+        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
        >>> df.groupby("id").applyInPandasWithState(
        ...     count_fn, outputStructType="id long, countAsString string",
        ...     stateStructType="len long", outputMode="Update",

From 38eec2d0f1ea1b77e443097906ea41f711c9cbd2 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 17:51:19 +0900
Subject: [PATCH 35/38] Fix a silly bug where the value of the state is removed or not initialized yet

---
 python/pyspark/sql/pandas/serializers.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 85ee6d02d9ea..928d6f9d84a2 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -578,7 +578,10 @@ def construct_state_pdf(state):
             state_properties = state.json().encode("utf-8")
             state_key_row_as_binary = state._keyAsUnsafe
-            state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))
+            if state.exists:
+                state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))
+            else:
+                state_object = None
             state_old_timeout_timestamp = state.oldTimeoutTimestamp
 
             state_dict = {

From 8133dcd29c7008e4c1eedb5065bcec6140c6df86 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 22:40:29 +0900
Subject: [PATCH 36/38] fix an edge-case being figured out from newer test case

---
 .../sql/execution/python/ApplyInPandasWithStateWriter.scala | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
index 70ded3dafd7e..60a228ddd73a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -162,6 +162,12 @@ class ApplyInPandasWithStateWriter(
    */
   def finalizeGroup(): Unit = {
     finalizeCurrentChunk(isLastChunkForGroup = true)
+
+    // If it exceeds the condition of batch (number of records) once all the data is received for
+    // the same group, finalize and construct a new batch.
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentArrowBatch()
+    }
   }
 
   /**
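The `state.exists` guard added above can be illustrated in isolation: when no state value exists (it was never set, or the user function removed it), there is nothing to pickle, so `None` is sent instead. Below is a hedged, standalone sketch of the same guard pattern using plain `pickle` and a made-up `MaybeValue` stand-in rather than PySpark's internal serializer.

    import pickle

    class MaybeValue:
        # Toy stand-in for a state handle that may or may not hold a value.
        def __init__(self, value=None, exists=False):
            self._value = value
            self.exists = exists

    def serialize_value(state):
        # Mirror the guard: only pickle when a value actually exists; send None otherwise,
        # so the receiving side can tell "no state value" apart from a real pickled payload.
        return pickle.dumps(state._value) if state.exists else None

    assert serialize_value(MaybeValue((3,), exists=True)) is not None
    assert serialize_value(MaybeValue()) is None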
From f1000487960fa19aff9979211db68e63ec4384e0 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 21 Sep 2022 22:50:15 +0900
Subject: [PATCH 37/38] loosen the requirement

---
 python/pyspark/sql/pandas/group_ops.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 4ac5cfcbdc74..05c5b49c81c8 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -242,9 +242,10 @@ def applyInPandasWithState(
 
         For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
         and the returned `pandas.DataFrame` across all invocations are combined as a
-        :class:`DataFrame`. Note that the user function should loop through and process all
-        elements in the iterator. The user function should not make a guess of the number of
-        elements in the iterator.
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate over all
+        elements and process them. On the other hand, the user function is not strictly required
+        to iterate through all elements in the iterator if it intends to read only part of the data.
 
         The `outputStructType` should be a :class:`StructType` describing the schema of all
         elements in the returned value, `pandas.DataFrame`. The column labels of all elements in

From dd7a6557b42cb0bccb5fce66ab7ac01c7abb510c Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Thu, 22 Sep 2022 05:42:46 +0900
Subject: [PATCH 38/38] reflect feedbacks

---
 python/pyspark/sql/pandas/group_ops.py   | 2 +-
 python/pyspark/sql/pandas/serializers.py | 4 ++--
 python/pyspark/worker.py                 | 3 ++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 05c5b49c81c8..0945c0078a2a 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -236,7 +236,7 @@ def applyInPandasWithState(
         state will be saved across invocations.
 
         The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
-        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
         of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
         :class:`pyspark.sql.streaming.state.GroupState`.
 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 928d6f9d84a2..ca249c75ea5c 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -473,8 +473,8 @@ def gen_data_and_state(batches):
             3.A. Extract the data out from entire data via the information of data range.
             3.B. Construct a new state instance if the state information is the first occurrence
                 for the current grouping key.
-            3.C. Leverage existing new state instance if the state instance is already available
-                for the current grouping key. (Meaning it's not the first occurrence.)
+            3.C. Leverage the existing state instance if it is already available for the current
+                grouping key. (Meaning it's not the first occurrence.)
             3.D. Remove the cache of state instance if the state information denotes the data is
                 the last chunk for current grouping key.
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5861330413d2..c1c3669701f7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -258,7 +258,8 @@ def verify_element(result):
             # the number of columns of result have to match the return type
             # but it is fine for result to have no columns at all if it is empty
             if not (
-                len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
+                len(result.columns) == len(return_type)
+                or (len(result.columns) == 0 and result.empty)
             ):
                 raise RuntimeError(
                     "Number of columns of the element (pandas.DataFrame) in return iterator "