diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5a13674e8bfb..7b31fa93c32e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -53,6 +53,7 @@ private[spark] object PythonEvalType { val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7ef0014ae751..5f4f4d494e13 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238b..acca8c00f2aa 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal from types import FunctionType from pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupState from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray @@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[ Callable[[Any, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[ + [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike] +] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb2959..d0f81e2f6335 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. 
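# --- A minimal sketch, not part of this patch --------------------------------
# With eval type 208 registered above, `pandas_udf` also accepts
# functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, which is how
# applyInPandasWithState (added in group_ops.py below) builds its UDF internally.
# The pass-through function body and the "id long, value string" schema are
# illustrative assumptions, not anything this diff defines.
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf

def passthrough(key, pdf_iter, state):
    # Yield the incoming pandas DataFrames unchanged; the state is left untouched.
    for pdf in pdf_iter:
        yield pdf

stateful_udf = pandas_udf(
    passthrough,
    returnType="id long, value string",
    functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
# ------------------------------------------------------------------------------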
@@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType): PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ]: # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered # at `apply` instead. diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e9..0945c0078a2a 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. # import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,125 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + """ + Applies the given function to each group of data, while maintaining a user-defined + per-group state. The result Dataset will represent the flattened record returned by the + function. + + For a streaming Dataset, the function will be invoked first for all input groups and then + for all timed out states where the input data is set to be empty. Updates to each group's + state will be saved across invocations. + + The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and + return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple + of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as + :class:`pyspark.sql.streaming.state.GroupState`. + + For each group, all columns are passed together as `pandas.DataFrame` to the user-function, + and the returned `pandas.DataFrame` across all invocations are combined as a + :class:`DataFrame`. Note that the user function should not make a guess of the number of + elements in the iterator. To process all data, the user function needs to iterate all + elements and process them. On the other hand, the user function is not strictly required to + iterate through all elements in the iterator if it intends to read a part of data. + + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The `stateStructType` should be :class:`StructType` describing the schema of the + user-defined state. The value of the state will be presented as a tuple, as well as the + update should be performed with the tuple. 
The corresponding Python types for + :class:DataType are supported. Please refer to the page + https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab). + + The size of each DataFrame in both the input and output can be arbitrary. The number of + DataFrames in both the input and output can also be arbitrary. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + func : function + a Python native function to be called on every group. It should take parameters + (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`]. + Note that the type of the key is tuple and the type of the state is + :class:`pyspark.sql.streaming.state.GroupState`. + outputStructType : :class:`pyspark.sql.types.DataType` or str + the type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + stateStructType : :class:`pyspark.sql.types.DataType` or str + the type of the user-defined state. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + the output mode of the function. + timeoutConf : str + timeout configuration for groups that do not receive data for a while. valid values + are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`. + + Examples + -------- + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.streaming.state import GroupStateTimeout + >>> def count_fn(key, pdf_iter, state): + ... assert isinstance(state, GroupStateImpl) + ... total_len = 0 + ... for pdf in pdf_iter: + ... total_len += len(pdf) + ... state.update((total_len,)) + ... yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]}) + >>> df.groupby("id").applyInPandasWithState( + ... count_fn, outputStructType="id long, countAsString string", + ... stateStructType="len long", outputMode="Update", + ... timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP + + Notes + ----- + This function requires a full shuffle. + + This API is experimental. + """ + + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 992e82b403a1..ca249c75ea5c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,7 +19,9 @@ Serializers for PyArrow and pandas conversions. 
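# --- A minimal sketch, not part of this patch --------------------------------
# End-to-end wiring of the applyInPandasWithState API defined in group_ops.py
# above into a streaming query. The socket source, host/port, column names and
# the running-count logic are illustrative assumptions.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.streaming.state import GroupStateTimeout

spark = SparkSession.builder.getOrCreate()

lines = (
    spark.readStream.format("socket")
    .option("host", "localhost")
    .option("port", 9999)
    .load()
)

def count_fn(key, pdf_iter, state):
    # Resume from the previously stored count, if any, and fold in the new rows.
    (total,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:
        total += len(pdf)
    state.update((total,))
    yield pd.DataFrame({"value": [key[0]], "count": [total]})

counts = lines.groupBy("value").applyInPandasWithState(
    count_fn,
    outputStructType="value string, count long",
    stateStructType="count long",
    outputMode="Update",
    timeoutConf=GroupStateTimeout.NoTimeout,
)

query = counts.writeStream.outputMode("update").format("console").start()
# ------------------------------------------------------------------------------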
See `pyspark.serializers` for more details. """ -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType class SpecialLengths: @@ -371,3 +373,354 @@ def load_stream(self, stream): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for applyInPandasWithState. + + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + state_object_schema : StructType + The type of state object represented as Spark SQL type + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. + """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + state_object_schema, + arrow_max_records_per_batch, + ): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name + ) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema + + self.result_state_df_type = StructType( + [ + StructField("properties", StringType()), + StructField("keyRowAsUnsafe", BinaryType()), + StructField("object", BinaryType()), + StructField("oldTimeoutTimestamp", LongType()), + ] + ) + + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.arrow_max_records_per_batch = arrow_max_records_per_batch + + def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of pair + (data chunk, state), and convert the data into a list of pandas.Series. + + Please refer the doc of inner function `gen_data_and_state` for more details how + this function works in overall. + + In addition, this function further groups the return of `gen_data_and_state` by the state + instance (same semantic as grouping by grouping key) and produces an iterator of data + chunks for each group, so that the caller can lazily materialize the data chunk. + """ + + import pyarrow as pa + import json + from itertools import groupby + from pyspark.sql.streaming.state import GroupState + + def construct_state(state_info_col): + """ + Construct state instance from the value of state information column. + """ + + state_info_col_properties = state_info_col["properties"] + state_info_col_key_row = state_info_col["keyRowAsUnsafe"] + state_info_col_object = state_info_col["object"] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + return GroupState( + keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties, + ) + + def gen_data_and_state(batches): + """ + Deserialize ArrowRecordBatches and return a generator of + `(a list of pandas.Series, state)`. + + The logic on deserialization is following: + + 1. 
Read the entire data part from Arrow RecordBatch. + 2. Read the entire state information part from Arrow RecordBatch. + 3. Loop through each state information: + 3.A. Extract the data out from entire data via the information of data range. + 3.B. Construct a new state instance if the state information is the first occurrence + for the current grouping key. + 3.C. Leverage the existing state instance if it is already available for the current + grouping key. (Meaning it's not the first occurrence.) + 3.D. Remove the cache of state instance if the state information denotes the data is + the last chunk for current grouping key. + + This deserialization logic assumes that Arrow RecordBatches contain the data with the + ordering that data chunks for same grouping key will appear sequentially. + + This function must avoid materializing multiple Arrow RecordBatches into memory at the + same time. And data chunks from the same grouping key should appear sequentially, to + further group them based on state instance (same state instance will be produced for + same grouping key). + """ + + state_for_current_group = None + + for batch in batches: + batch_schema = batch.schema + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema( + [ + batch_schema[-1], + ] + ) + + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = batch_columns[-1] + + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays( + [ + state_column, + ], + schema=state_schema, + ) + + state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[state_idx] + + if not state_info_col: + # no more data with grouping key + state + break + + data_start_offset = state_info_col["startOffset"] + num_data_rows = state_info_col["numRows"] + is_last_chunk = state_info_col["isLastChunk"] + + if state_for_current_group: + # use the state, we already have state for same group and there should be + # some data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = construct_state(state_info_col) + + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not state_for_current_group: + # there's no cached state but expected to have additional data in same group + # cache the current state + state_for_current_group = state + + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + + # state info + yield ( + data_pandas, + state, + ) + + _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + + data_state_generator = gen_data_and_state(_batches) + + # state will be same object for same grouping key + for _state, _data in groupby(data_state_generator, key=lambda x: x[1]): + yield ( + _data, + _state, + ) + + def dump_stream(self, iterator, stream): + """ + Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow + RecordBatches, and write batches to stream. 
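# --- A minimal sketch, not part of this patch --------------------------------
# load_stream above groups the (data, state) pairs with itertools.groupby keyed
# on the state instance, so consecutive chunks produced with the same GroupState
# object are handed to the caller as one lazy group. The plain objects and
# string "chunks" below stand in for real pandas data and GroupState instances.
from itertools import groupby

state_a, state_b = object(), object()
pairs = [("chunk-1", state_a), ("chunk-2", state_a), ("chunk-3", state_b)]

for state, group in groupby(pairs, key=lambda pair: pair[1]):
    # Each iteration corresponds to one grouping key; `group` is a lazy iterator
    # of the chunks that were produced with that same state instance.
    print(state, [chunk for chunk, _ in group])
# ------------------------------------------------------------------------------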
+ """ + + import pandas as pd + import pyarrow as pa + + def construct_state_pdf(state): + """ + Construct a pandas DataFrame from the state instance. + """ + + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + if state.exists: + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + else: + state_object = None + state_old_timeout_timestamp = state.oldTimeoutTimestamp + + state_dict = { + "properties": [ + state_properties, + ], + "keyRowAsUnsafe": [ + state_key_row_as_binary, + ], + "object": [ + state_object, + ], + "oldTimeoutTimestamp": [ + state_old_timeout_timestamp, + ], + } + + return pd.DataFrame.from_dict(state_dict) + + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each + one matches to the single struct field for Arrow schema, hence the return value of + Arrow RecordBatch will have schema with two fields, in `data`, `state` order. + (Readers are expected to access the field via position rather than the name. We do + not guarantee the name of the field.) + + Note that Arrow RecordBatch requires all columns to have all same number of rows, + hence this function inserts empty data for state/data with less elements to compensate. + """ + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + + empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt + empty_row_cnt_in_state = max_data_cnt - state_data_cnt + + empty_rows_pdf = pd.DataFrame( + dict.fromkeys(pa.schema(pdf_schema).names), + index=[x for x in range(0, empty_row_cnt_in_data)], + ) + empty_rows_state = pd.DataFrame( + columns=["properties", "keyRowAsUnsafe", "object", "oldTimeoutTimestamp"], + index=[x for x in range(0, empty_row_cnt_in_state)], + ) + + pdfs.append(empty_rows_pdf) + state_pdfs.append(empty_rows_state) + + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + + return self._create_batch( + [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)] + ) + + def serialize_batches(): + """ + Read through an iterator of (iterator of pandas DataFrame, state), and serialize them + to Arrow RecordBatches. + + This function does batching on constructing the Arrow RecordBatch; a batch will be + serialized to the Arrow RecordBatch when the total number of records exceeds the + configured threshold. + """ + # a set of variables for the state of current batch which will be converted to Arrow + # RecordBatch. + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + + return_schema = None + + for data in iterator: + # data represents the result of each call of user function + packaged_result = data[0] + + # There are two results from the call of user function: + # 1) iterator of pandas DataFrame (output) + # 2) updated state instance + pdf_iter = packaged_result[0][0] + state = packaged_result[0][1] + + # This is static and won't change across batches. + return_schema = packaged_result[1] + + for pdf in pdf_iter: + # We ignore empty pandas DataFrame. + if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) + + # If the total number of records in current batch exceeds the configured + # threshold, time to construct the Arrow RecordBatch from the batch. 
+ if pdf_data_cnt > self.arrow_max_records_per_batch: + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) + + # Reset the variables to start with new batch for further data. + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + + yield batch + + # This has to be performed 'after' evaluating all elements in iterator, so that + # the user function has been completed and the state is guaranteed to be updated. + state_pdf = construct_state_pdf(state) + + state_pdfs.append(state_pdf) + state_data_cnt += 1 + + # processed all output, but current batch may not be flushed yet. + if pdf_data_cnt > 0 or state_data_cnt > 0: + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) + + yield batch + + def init_stream_yield_batches(batches): + """ + This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a + START_ARROW_STREAM before the Arrow stream is sent. + + START_ARROW_STREAM should be sent after creating the first record batch so in case of + an error, it can be sent back to the JVM before the Arrow stream starts. + """ + should_write_start_length = True + + for batch in batches: + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + batches_to_write = init_stream_yield_batches(serialize_batches()) + + return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 842eff322330..66b225e1b10c 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -20,16 +20,24 @@ from pyspark.sql.types import DateType, Row, StructType -__all__ = ["GroupStateImpl", "GroupStateTimeout"] +__all__ = ["GroupState", "GroupStateTimeout"] class GroupStateTimeout: + """ + Represents the type of timeouts possible for the Dataset operations applyInPandasWithState. + """ + NoTimeout: str = "NoTimeout" ProcessingTimeTimeout: str = "ProcessingTimeTimeout" EventTimeTimeout: str = "EventTimeTimeout" -class GroupStateImpl: +class GroupState: + """ + Wrapper class for interacting with per-group state data in `applyInPandasWithState`. + """ + NO_TIMESTAMP: int = -1 def __init__( @@ -76,10 +84,16 @@ def __init__( @property def exists(self) -> bool: + """ + Whether state exists or not. + """ return self._defined @property def get(self) -> Tuple: + """ + Get the state value if it exists, or throw ValueError. + """ if self.exists: return tuple(self._value) else: @@ -87,6 +101,9 @@ def get(self) -> Tuple: @property def getOption(self) -> Optional[Tuple]: + """ + Get the state value if it exists, or return None. + """ if self.exists: return tuple(self._value) else: @@ -94,6 +111,10 @@ def getOption(self) -> Optional[Tuple]: @property def hasTimedOut(self) -> bool: + """ + Whether the function has been called because the key has timed out. + This can return true only when timeouts are enabled. + """ return self._has_timed_out # NOTE: this function is only available to PySpark implementation due to underlying @@ -103,6 +124,9 @@ def oldTimeoutTimestamp(self) -> int: return self._old_timeout_timestamp def update(self, newValue: Tuple) -> None: + """ + Update the value of the state. The value of the state cannot be null. 
+ """ if newValue is None: raise ValueError("'None' is not a valid state value") @@ -112,11 +136,18 @@ def update(self, newValue: Tuple) -> None: self._removed = False def remove(self) -> None: + """ + Remove this state. + """ self._defined = False self._updated = False self._removed = True def setTimeoutDuration(self, durationMs: int) -> None: + """ + Set the timeout duration in ms for this key. + Processing time timeout must be enabled. + """ if isinstance(durationMs, str): # TODO(SPARK-40437): Support string representation of durationMs. raise ValueError("durationMs should be int but get :%s" % type(durationMs)) @@ -133,6 +164,11 @@ def setTimeoutDuration(self, durationMs: int) -> None: # TODO(SPARK-40438): Implement additionalDuration parameter. def setTimeoutTimestamp(self, timestampMs: int) -> None: + """ + Set the timeout timestamp for this key as milliseconds in epoch time. + This timestamp cannot be older than the current watermark. + Event time timeout must be enabled. + """ if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: raise RuntimeError( "Cannot set timeout duration without enabling processing time timeout in " @@ -146,7 +182,7 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: raise ValueError("Timeout timestamp must be positive") if ( - self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + self._event_time_watermark_ms != GroupState.NO_TIMESTAMP and timestampMs < self._event_time_watermark_ms ): raise ValueError( @@ -157,6 +193,10 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None: self._timeout_timestamp = timestampMs def getCurrentWatermarkMs(self) -> int: + """ + Get the current event time watermark as milliseconds in epoch time. + In a streaming query, this can be called only when watermark is set. + """ if not self._watermark_present: raise RuntimeError( "Cannot get event time watermark timestamp without setting watermark before " @@ -165,6 +205,11 @@ def getCurrentWatermarkMs(self) -> int: return self._event_time_watermark_ms def getCurrentProcessingTimeMs(self) -> int: + """ + Get the current processing time as milliseconds in epoch time. + In a streaming query, this will return a constant value throughout the duration of a + trigger, even if the trigger is re-executed. + """ return self._batch_processing_time_ms def __str__(self) -> str: @@ -174,6 +219,10 @@ def __str__(self) -> str: return "GroupState()" def json(self) -> str: + """ + Convert the internal values of instance into JSON. This is used to send out the update + from Python worker to executor. 
+ """ return json.dumps( { # Constructor diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d040..da9a245bb711 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ def returnType(self) -> DataType: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(WithState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(WithState): return type must be a " "StructType." ) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d8..c1c3669701f7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. has_resource_module = True @@ -57,6 +58,7 @@ ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType @@ -207,6 +209,90 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + """ + Provides a new lambda instance wrapping user function of applyInPandasWithState. + + The lambda instance receives (key series, iterator of value series, state) and performs + some conversion to be adapted with the signature of user function. + + See the function doc of inner function `wrapped` for more details on what adapter does. + See the function doc of `mapper` function for + `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on + the input parameters of lambda function. + + Along with the returned iterator, the lambda instance will also produce the return_type as + converted to the arrow schema. + """ + + def wrapped(key_series, value_series_gen, state): + """ + Provide an adapter of the user function performing below: + + - Extract the first value of all columns in key series and produce as a tuple. + - If the state has timed out, call the user function with empty pandas DataFrame. + - If not, construct a new generator which converts each element of value series to + pandas DataFrame (lazy evaluation), and call the user function with the generator + - Verify each element of returned iterator to check the schema of pandas DataFrame. + """ + import pandas as pd + + key = tuple(s[0] for s in key_series) + + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. 
+ values = [ + pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), + ] + else: + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result_iter = f(key, values, state) + + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) + or (len(result.columns) == 0 and result.empty) + ): + raise RuntimeError( + "Number of columns of the element (pandas.DataFrame) in return iterator " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + return result + + if isinstance(result_iter, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) + ) + + try: + iter(result_iter) + except TypeError: + raise TypeError( + "Return type of the user-defined function should be " + "iterable, but is {}".format(type(result_iter)) + ) + + result_iter_with_validation = (verify_element(x) for x in result_iter) + + return ( + result_iter_with_validation, + state, + ) + + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -311,6 +397,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -336,6 +424,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +434,10 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + state_object_schema = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -361,6 +454,19 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + arrow_max_records_per_batch = runner_conf.get( + "spark.sql.execution.arrow.maxRecordsPerBatch", 10000 + ) + arrow_max_records_per_batch = int(arrow_max_records_per_batch) + + ser = ApplyInPandasWithStateSerializer( + timezone, + safecheck, + 
assign_cols_by_name, + state_object_schema, + arrow_max_records_per_batch, + ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() else: @@ -486,6 +592,43 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + """ + The function receives (iterator of data, state) and performs extraction of key and + value from the data, with retaining lazy evaluation. + + See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input + and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will + be used. + """ + from itertools import tee + + state = a[1] + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. + keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c11ce7d3b90f..84795203fd17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging { case s: Aggregate if s.isStreaming => true case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true case d: Deduplicate if d.isStreaming => true case _ => false } @@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } + val applyInPandasWithStates = plan.collect { + case f: FlatMapGroupsInPandasWithState if f.isStreaming => f + } + + // Disallow multiple `applyInPandasWithState`s. 
+ if (applyInPandasWithStates.size > 1) { + throwError( + "Multiple applyInPandasWithStates are not supported on a streaming " + + "DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging { } } + // applyInPandasWithState + case m: FlatMapGroupsInPandasWithState if m.isStreaming => + // Check compatibility with output modes and aggregations in query + val aggsInQuery = collectStreamingAggregates(plan) + + if (aggsInQuery.isEmpty) { + // applyInPandasWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "applyInPandasWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "applyInPandasWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // applyInPandasWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "applyInPandasWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "applyInPandasWithState in append mode is not supported after " + + "aggregation on a streaming DataFrame/Dataset") + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "applyInPandasWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b350834..e97ff7808f17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +/** + * Similar with [[FlatMapGroupsWithState]]. 
Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with an pandas DataFrame representation and the + * grouping key (tuple). + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee3252187..0429fd27a41e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,49 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + /** + * Applies a grouped vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: iterator of `pandas.DataFrame` -> + * iterator of `pandas.DataFrame`. + * For each group, all elements in the group are passed as an iterator of `pandas.DataFrame` + * along with corresponding state, and the results for all groups are combined into a new + * [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. 
+ */ + private[sql] def applyInPandasWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = FlatMapGroupsInPandasWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7bea..c64a123e3a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FlatMapGroupsInPandasWithState( + func, groupAttr, outputAttr, stateType, outputMode, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = python.FlatMapGroupsInPandasWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. */ @@ -793,6 +812,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-40443): support applyInPandasWithState in batch query + throw new UnsupportedOperationException( + "applyInPandasWithState is unsupported in batch query. 
Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 7abca5f0e332..2988c0fb5187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -44,7 +44,7 @@ object ArrowWriter { new ArrowWriter(root, children.toArray) } - private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) @@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } + def sizeInBytes(): Int = { + var i = 0 + var bytes = 0 + while (i < fields.size) { + bytes += fields(i).getSizeInBytes() + i += 1 + } + bytes + } + def finish(): Unit = { root.setRowCount(count) fields.foreach(_.finish()) @@ -132,6 +142,10 @@ private[arrow] abstract class ArrowFieldWriter { count += 1 } + def getSizeInBytes(): Int = { + valueVector.getBufferSizeFor(count) + } + def finish(): Unit = { valueVector.setValueCount(count) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala new file mode 100644 index 000000000000..bd8c72029dcb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + + +/** + * A variant implementation of [[ArrowPythonRunner]] to serve the operation + * applyInPandasWithState. + * + * Unlike normal ArrowPythonRunner which both input and output (executor <-> python worker) + * are InternalRow, applyInPandasWithState has side data (state information) in both input + * and output along with data, which requires different struct on Arrow RecordBatch. + */ +class ApplyInPandasWithStatePythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + inputSchema: StructType, + override protected val timeZoneId: String, + initialWorkerConf: Map[String, String], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + outputSchema: StructType, + stateValueSchema: StructType) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + with PythonArrowInput[InType] + with PythonArrowOutput[OutType] { + + private val sqlConf = SQLConf.get + + override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) + + override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback + + override val bufferSize: Int = { + val configuredSize = sqlConf.pandasUDFBufferSize + if (configuredSize < 4) { + logWarning("Pandas execution requires more than 4 bytes. Please configure bigger value " + + s"for the configuration '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'. " + + "Force using the value '4'.") + 4 + } else { + configuredSize + } + } + + private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + + // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. + // Configurations are both applied to executor and Python worker, set them to the worker conf + // to let Python worker read the config properly. + override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + /** + * This method sends out the additional metadata before sending out actual data. + * + * Specifically, this class overrides this method to also write the schema for state value. 
+ */ + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the schema for state value + PythonRDD.writeUTF(stateValueSchema.json, stream) + } + + /** + * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and + * write constructed RecordBatches to the writer. + * + * See [[ApplyInPandasWithStateWriter]] for more details. + */ + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Unit = { + val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + w.startNewGroup(keyRow, groupState) + + while (dataIter.hasNext) { + val dataRow = dataIter.next() + w.writeRow(dataRow) + } + + w.finalizeGroup() + } + + w.finalizeData() + } + + /** + * Deserialize ColumnarBatch received from the Python worker to produce the output. Schema info + * for given ColumnarBatch is also provided as well. + */ + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. + assert(batch.numRows() > 0) + assert(schema.length == 2) + + def getColumnarBatchForStructTypeColumn( + batch: ColumnarBatch, + ordinal: Int, + expectedType: StructType): ColumnarBatch = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector] + val dataType = schema(ordinal).dataType.asInstanceOf[StructType] + assert(dataType.sameType(expectedType), + s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType") + + val outputVectors = dataType.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch + } + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema) + dataBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. + None + } else { + Some(row) + } + } + } + + def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) + + stateMetadataBatch.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER + // for the schema. 
+ val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) + } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) + } + } + } + + (constructIterForState(batch), constructIterForData(batch)) + } +} + +object ApplyInPandasWithStatePythonRunner { + type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) + type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long) + type OutType = (Iterator[OutTypeForState], Iterator[InternalRow]) + + val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("oldTimeoutTimestamp", LongType) + ) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala new file mode 100644 index 000000000000..60a228ddd73a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with + * bin-packing and chunking. 
+ *
+ * The caller only needs to call the public methods of this class - `startNewGroup`, `writeRow`,
+ * `finalizeGroup` and `finalizeData` - and this class will write the data and state into Arrow
+ * RecordBatches, performing bin-packing and chunking internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the Arrow schema
+ * shown below:
+ *  - data fields
+ *  - state field
+ *    - nested schema (refer to ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer to the code comments in the implementation to see how the writes of data and
+ * state into Arrow RecordBatches work, with consideration of bin-packing and chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState needs to produce
+  // additional data, the `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only a single VectorSchemaRoot, which means all Arrow
+  // RecordBatches sent out from the ArrowStreamWriter must have the same schema. Hence, we have
+  // to construct a single Arrow schema that contains both data and state, and construct Arrow
+  // RecordBatches that contain both data and state.
+  //
+  // To achieve this, we extend the schema of the input data with a column for state at the end.
+  // We also logically group the columns by family (data vs state) and initialize a writer for
+  // each, since it is much easier and probably more performant to write each row directly
+  // rather than projecting the row to match the overall schema.
+  //
+  // Although an Arrow RecordBatch lets us write the data as columnar, we figured out that it
+  // gives strange outputs if we don't ensure that all columns have the same number of values.
+  // Since there is at least one data row for a grouping key (we ensure this for the case of
+  // handling timed-out state as well) whereas there is only one state per grouping key, we have
+  // to fill up the empty rows on the state side so that both have the same number of rows.
+  private val arrowWriterForData = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.dropRight(1))
+  private val arrowWriterForState = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+  // - Bin-packing
+  //
+  //   We bin-pack the data from multiple groups into one Arrow RecordBatch to improve
+  //   performance. In many cases, the amount of data per grouping key is quite small, which
+  //   would not otherwise maximize the benefits of using Arrow.
+  //
+  //   The Python worker has to split the record batch back into per-group data to convert each
+  //   group's data to Pandas, but fortunately Arrow RecordBatch provides a way to slice a range
+  //   of the data as a view, effectively "zero-copy". To help with slicing the range of data,
+  //   we provide the "start offset" and the "number of rows" in the state metadata.
+  //
+  //   We don't bin-pack all groups into a single record batch - we have a limit on the number
+  //   of rows in the current Arrow RecordBatch, which stops us from adding the next group.
+  //
+  // - Chunking
+  //
+  //   We also chunk the data from a single group into multiple Arrow RecordBatches to ensure
+  //   scalability. Note that we don't know the volume (number of rows, overall size) of the data
+  //   for a specific group key before we have read the entire data. The easiest approach that
+  //   addresses both bin-packing and chunking is to check the number of rows in the current
+  //   Arrow RecordBatch on each row write (see the example below).
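+  //
+  //   For instance (with illustrative numbers), suppose arrowMaxRecordsPerBatch = 3 and two
+  //   groups arrive with 2 rows (group A) and 4 rows (group B):
+  //   - the first batch carries data rows [a1, a2, b1] and state rows
+  //     [A(startOffset=0, numRows=2, isLastChunk=true),
+  //      B(startOffset=2, numRows=1, isLastChunk=false), one empty padding row],
+  //     since the attempt to write b2 triggers the flush;
+  //   - the second batch carries data rows [b2, b3, b4] and state rows
+  //     [B(startOffset=0, numRows=3, isLastChunk=true), two empty padding rows].
+  //   Group A fits in a single chunk, group B is chunked across the two batches, and each batch
+  //   pads the state column with empty rows to match the number of data rows.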
+  //
+  // - Data and State
+  //
+  //   Since we apply bin-packing and chunking, there should be a way to distinguish each chunk
+  //   within the entire data part of an Arrow RecordBatch. We leverage the state metadata to
+  //   also carry the "metadata" of the data part, which distinguishes each chunk from the entire
+  //   data. As a result, state metadata has a 1-1 relationship with a "chunk", instead of a
+  //   "grouping key".
+  //
+  // - Consideration
+  //
+  //   Since the number of rows in an Arrow RecordBatch does not represent its actual size
+  //   (bytes), the limit should be set very conservatively. Using a small limit does not
+  //   introduce correctness issues.
+
+  // variables for tracking the current grouping key and state
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  // variables for tracking the status of the current batch
+  private var totalNumRowsForBatch = 0
+  private var totalNumStatesForBatch = 0
+
+  // variables for tracking the status of the current chunk
+  private var startOffsetForCurrentChunk = 0
+  private var numRowsForCurrentChunk = 0
+
+  /**
+   * Indicates to the writer that a new group starts with the given grouping key.
+   *
+   * @param keyRow The grouping key row for the current group.
+   * @param groupState The instance of GroupStateImpl for the current group.
+   */
+  def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = {
+    currentGroupKeyRow = keyRow
+    currentGroupState = groupState
+  }
+
+  /**
+   * Indicates to the writer that a row should be written for the current group.
+   *
+   * @param dataRow The row to write for the current group.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If the current batch has already reached the limit on the number of records and there is
+    // more data for the same group, finalize the current chunk and batch and start a new batch.
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentChunk(isLastChunkForGroup = false)
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+
+    numRowsForCurrentChunk += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Indicates to the writer that the current group is finalized and no further rows will be
+   * bound to it.
+   */
+  def finalizeGroup(): Unit = {
+    finalizeCurrentChunk(isLastChunkForGroup = true)
+
+    // If the current batch has reached the limit on the number of records once all data for the
+    // same group has been received, finalize it and start a new batch.
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  /**
+   * Indicates to the writer that all groups have been processed.
+   */
+  def finalizeData(): Unit = {
+    if (totalNumRowsForBatch > 0) {
+      // We still have some rows in the current record batch. They need to be finalized as well.
+ finalizeCurrentArrowBatch() + } + } + + private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = { + val children = fieldVectors.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = { + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for next chunk would be same as the total number of rows for batch, + // unless the next chunk starts with new batch. + startOffsetForCurrentChunk = totalNumRowsForBatch + numRowsForCurrentChunk = 0 + } + + private def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurrentChunk = 0 + numRowsForCurrentChunk = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + } +} + +object ApplyInPandasWithStateWriter { + // This schema contains both state metadata and the metadata of the chunk. Refer the code comment + // of "Data and State" for more details. + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + /* + Metadata of the state + */ + + // properties of state instance (excluding state value) in json format + StructField("properties", StringType), + // key row as UnsafeRow, Python worker won't touch this value but send the value back to + // executor when sending an update of state + StructField("keyRowAsUnsafe", BinaryType), + // state value + StructField("object", BinaryType), + + /* + Metadata of the chunk + */ + + // start offset of the data chunk from entire data + StructField("startOffset", IntegerType), + // the number of rows for the data chunk + StructField("numRows", IntegerType), + // whether the current data chunk is the last one for current grouping key or not + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
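+  // finalizeCurrentArrowBatch() appends this row once for every data row in the batch that has
+  // no corresponding state metadata row, so the state column ends up with the same number of
+  // rows as the data columns; all six fields (matching STATE_METADATA_SCHEMA) are null.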
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b5466..b39787b12a48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f998..f0e815e966e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 000000000000..159f805f7347 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-40444): Add the support of initial state. 
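+  // Initial state is not supported for applyInPandasWithState yet, so the initial-state members
+  // required by FlatMapGroupsWithStateExecBase are stubbed out and `hasInitialState` is fixed
+  // to false.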
+ override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + + override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + val processIter = groupedIter.map { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj)) + } + + process(processIter, hasTimedOut = false) + } + + override def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + + override def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + + val processIter = timingOutPairs.map { stateData => + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + + (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) + } + + process(processIter, hasTimedOut = true) + } else Iterator.empty + } + + private def process( + iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val runner = new ApplyInPandasWithStatePythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, + 
stateEncoder.asInstanceOf[ExpressionEncoder[Row]], + groupingAttributes.toStructType, + outAttributes.toStructType, + stateType) + + val context = TaskContext.get() + + val processIter = iter.map { case (keyRow, stateData, valueIter) => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + (keyRow, groupedState, valueIter) + } + runner.compute(processIter, context.partitionId(), context).flatMap { + case (stateIter, outputIter) => + // When the iterator is consumed, then write changes to state. + // state does not affect each others, hence when to update does not affect to the result. + def onIteratorCompletion: Unit = { + stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) => + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIter, onIteratorCompletion).map { row => + numOutputRows += 1 + row + } + } + } + + override protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4e..078876664062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. 
] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 6168d0f867ad..bf66791183ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e973..f386282a0b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,6 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,13 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: FlatMapGroupsInPandasWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7bc..022fd1239ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a 
[[StateStore]]. */ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType,