diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5a13674e8bfbf..7b31fa93c32e5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -53,6 +53,7 @@ private[spark] object PythonEvalType { val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 + val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aef79c7882ca1..484a07c18ed0e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3051,7 +3051,7 @@ private[spark] object Utils extends Logging { * and return the trailing part after the last dollar sign in the middle */ @scala.annotation.tailrec - private def stripDollars(s: String): String = { + def stripDollars(s: String): String = { val lastDollarIndex = s.lastIndexOf('$') if (lastDollarIndex < s.length - 1) { // The last char is not a dollar sign diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d776410fd2c32..4d8e604a655d0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -450,6 +450,7 @@ def __hash__(self): "pyspark.sql.tests.test_group", "pyspark.sql.tests.test_pandas_cogrouped_map", "pyspark.sql.tests.test_pandas_grouped_map", + "pyspark.sql.tests.test_pandas_grouped_map_with_state", "pyspark.sql.tests.test_pandas_map", "pyspark.sql.tests.test_arrow_map", "pyspark.sql.tests.test_pandas_udf", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a8d1fd03b219c..c9eaba53320db 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238ba..82b861c51cf5c 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -22,19 +22,19 @@ from typing import ( Iterable, NewType, Tuple, - Type, TypeVar, Union, ) from typing_extensions import Protocol, Literal from types import FunctionType -from pyspark.sql._typing import LiteralType +import pyarrow from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray -import pyarrow +from 
pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupStateImpl ArrayLike = NDArray DataFrameLike = PandasDataFrame @@ -51,6 +51,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -253,9 +254,11 @@ PandasScalarIterFunction = Union[ PandasGroupedMapFunction = Union[ Callable[[DataFrameLike], DataFrameLike], - Callable[[Any, DataFrameLike], DataFrameLike], + Callable[[Tuple, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[[Tuple, DataFrameLike, GroupStateImpl], DataFrameLike] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb29590..1c6c2219edcec 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. @@ -399,6 +400,7 @@ def _create_pandas_udf(f, returnType, evalType): ) elif evalType in [ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e9e..948fe5ce71355 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. 
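# Example (illustrative sketch, not part of the patch): a user function matching the new
# PandasGroupedMapFunctionWithState alias added above. It receives the grouping key as a
# tuple, the group's rows as a pandas DataFrame, and the mutable group state, and returns
# a pandas DataFrame whose columns match the StructType passed as the return type.
from typing import Tuple

import pandas as pd

from pyspark.sql.streaming.state import GroupStateImpl


def count_per_key(key: Tuple, pdf: pd.DataFrame, state: GroupStateImpl) -> pd.DataFrame:
    # Read the previous running count if this key already has state, then fold in this batch.
    previous = state.get[0] if state.exists else 0
    total = previous + len(pdf)
    state.update((total,))
    return pd.DataFrame({"key": [key[0]], "count": [total]})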
# import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,45 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py new file mode 100644 index 0000000000000..6281dbadba61b --- /dev/null +++ b/python/pyspark/sql/streaming/state.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
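# Example (illustrative sketch, not part of the patch): an end-to-end use of the
# applyInPandasWithState method added above. It assumes an active SparkSession named
# `spark`; the input path is the test-support directory used by the Python test suite
# further down in this diff. Schemas are passed as DDL strings, which the method parses
# with _parse_datatype_string as shown above.
import pandas as pd

from pyspark.sql.streaming.state import GroupStateTimeout


def count_fn(key, pdf, state):
    # Keep a running count per key in the state row and emit it for this micro-batch.
    total = (state.get[0] if state.exists else 0) + len(pdf)
    state.update((total,))
    return pd.DataFrame({"value": [key[0]], "count": [total]})


sdf = spark.readStream.format("text").load("python/test_support/sql/streaming")

query = (
    sdf.groupBy(sdf["value"])
    .applyInPandasWithState(
        count_fn,
        outputStructType="value STRING, count BIGINT",
        stateStructType="count BIGINT",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )
    .writeStream.outputMode("update")
    .format("console")
    .start()
)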
+# +import datetime +import json +from typing import Tuple, Optional + +from pyspark.sql.types import DateType, Row, StructType + +__all__ = ["GroupStateImpl", "GroupStateTimeout"] + + +class GroupStateTimeout: + NoTimeout: str = "NoTimeout" + ProcessingTimeTimeout: str = "ProcessingTimeTimeout" + EventTimeTimeout: str = "EventTimeTimeout" + + +class GroupStateImpl: + NO_TIMESTAMP: int = -1 + + def __init__( + self, + # JVM Constructor + optionalValue: Row, + batchProcessingTimeMs: int, + eventTimeWatermarkMs: int, + timeoutConf: str, + hasTimedOut: bool, + watermarkPresent: bool, + # JVM internal state. + defined: bool, + updated: bool, + removed: bool, + timeoutTimestamp: int, + # Python internal state. + keySchema: StructType, + ) -> None: + self._value = optionalValue + self._batch_processing_time_ms = batchProcessingTimeMs + self._event_time_watermark_ms = eventTimeWatermarkMs + + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + self._timeout_conf = timeoutConf + + self._has_timed_out = hasTimedOut + self._watermark_present = watermarkPresent + + self._defined = defined + self._updated = updated + self._removed = removed + self._timeout_timestamp = timeoutTimestamp + + self._key_schema = keySchema + + @property + def exists(self) -> bool: + return self._defined + + @property + def get(self) -> Tuple: + if self.exists: + return tuple(self._value) + else: + raise ValueError("State is either not defined or has already been removed") + + @property + def getOption(self) -> Optional[Tuple]: + if self.exists: + return tuple(self._value) + else: + return None + + @property + def hasTimedOut(self) -> bool: + return self._has_timed_out + + def update(self, newValue: Tuple) -> None: + if newValue is None: + raise ValueError("'None' is not a valid state value") + + self._value = Row(*newValue) + self._defined = True + self._updated = True + self._removed = False + + def remove(self) -> None: + self._defined = False + self._updated = False + self._removed = True + + def setTimeoutDuration(self, durationMs: int) -> None: + if isinstance(durationMs, str): + # TODO(SPARK-XXXXX): Support string representation of durationMs. + raise ValueError("durationMs should be int but get :%s" % type(durationMs)) + + if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout: + raise RuntimeError( + "Cannot set timeout duration without enabling processing time timeout in " + "applyInPandasWithState" + ) + + if durationMs <= 0: + raise ValueError("Timeout duration must be positive") + self._timeout_timestamp = durationMs + self._batch_processing_time_ms + + # TODO(SPARK-XXXXX): Implement additionalDuration parameter. 
+ def setTimeoutTimestamp(self, timestampMs: int) -> None: + if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: + raise RuntimeError( + "Cannot set timeout duration without enabling processing time timeout in " + "applyInPandasWithState" + ) + + if isinstance(timestampMs, datetime.datetime): + timestampMs = DateType().toInternal(timestampMs) + + if timestampMs <= 0: + raise ValueError("Timeout timestamp must be positive") + + if ( + self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + and timestampMs < self._event_time_watermark_ms + ): + raise ValueError( + "Timeout timestamp (%s) cannot be earlier than the " + "current watermark (%s)" % (timestampMs, self._event_time_watermark_ms) + ) + + self._timeout_timestamp = timestampMs + + def getCurrentWatermarkMs(self) -> int: + if not self._watermark_present: + raise RuntimeError( + "Cannot get event time watermark timestamp without setting watermark before " + "applyInPandasWithState" + ) + return self._event_time_watermark_ms + + def getCurrentProcessingTimeMs(self) -> int: + return self._batch_processing_time_ms + + def __str__(self) -> str: + if self.exists: + return "GroupState(%s)" % self.get + else: + return "GroupState()" + + def json(self) -> str: + return json.dumps( + { + # Constructor + "optionalValue": None, # Note that optionalValue will be manually serialized. + "batchProcessingTimeMs": self._batch_processing_time_ms, + "eventTimeWatermarkMs": self._event_time_watermark_ms, + "timeoutConf": self._timeout_conf, + "hasTimedOut": self._has_timed_out, + "watermarkPresent": self._watermark_present, + # JVM internal state. + "defined": self._defined, + "updated": self._updated, + "removed": self._removed, + "timeoutTimestamp": self._timeout_timestamp, + } + ) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py new file mode 100644 index 0000000000000..a9a56c557fabd --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
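# Example (illustrative, not part of the patch): exercising GroupStateImpl directly to show
# how the worker-side state object behaves. All constructor values below are made up for the
# illustration; in real runs they are read from the JVM in read_udfs (see worker.py further
# down in this diff).
from pyspark.sql.types import Row, StructType, StructField, LongType
from pyspark.sql.streaming.state import GroupStateImpl, GroupStateTimeout

state = GroupStateImpl(
    optionalValue=Row(count=1),
    batchProcessingTimeMs=1_000_000,
    eventTimeWatermarkMs=GroupStateImpl.NO_TIMESTAMP,
    timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    hasTimedOut=False,
    watermarkPresent=False,
    defined=True,
    updated=False,
    removed=False,
    timeoutTimestamp=GroupStateImpl.NO_TIMESTAMP,
    keySchema=StructType([StructField("key", LongType())]),
)

assert state.exists and state.get == (1,)
state.update((2,))                # marks the state as defined and updated
state.setTimeoutDuration(10_000)  # timeout timestamp = durationMs + batchProcessingTimeMs
print(state.json())               # JSON properties sent back to the JVM; the value row is not embedded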
+# + +import unittest +from typing import cast + +from pyspark.sql.streaming.state import GroupStateTimeout, GroupStateImpl +from pyspark.sql.types import ( + LongType, + StringType, + StructType, + StructField, + Row, +) +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pandas: + import pandas as pd + +if have_pyarrow: + import pyarrow as pa # noqa: F401 + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class GroupedMapInPandasWithStateTests(ReusedSQLTestCase): + def test_apply_in_pandas_with_state_basic(self): + df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") + + for q in self.spark.streams.active: + q.stop() + self.assertTrue(df.isStreaming) + + output_type = StructType( + [StructField("key", StringType()), StructField("countAsString", StringType())] + ) + state_type = StructType([StructField("c", LongType())]) + + def func(key, pdf, state): + assert isinstance(state, GroupStateImpl) + state.update((len(pdf),)) + assert state.get[0] == 1 + return pd.DataFrame({"key": [key[0]], "countAsString": [str(len(pdf))]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.collect()), + {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, + ) + + q = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_grouped_map_with_state import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d0400..417896ab738c7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ def returnType(self) -> DataType: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(withState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(withState): return type must be a " "StructType." 
) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d81..bdedc88e92a37 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. has_resource_module = True @@ -62,6 +63,7 @@ from pyspark.sql.types import StructType from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle +from pyspark.sql.streaming.state import GroupStateImpl pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -207,6 +209,37 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type, state): + def wrapped(key_series, value_series): + import pandas as pd + + key = tuple(s[0] for s in key_series) + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. + result = f(key, pd.DataFrame(columns=pd.concat(value_series, axis=1).columns), state) + else: + result = f(key, pd.concat(value_series, axis=1), state) + + if not isinstance(result, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty + ): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + return result + + return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -281,7 +314,7 @@ def wrapped(begin_index, end_index, *series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): +def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, state=None): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] chained_func = None @@ -311,6 +344,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, state) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -327,6 +362,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): def read_udfs(pickleSer, infile, eval_type): runner_conf = {} + # Used for state support in Structured Streaming. 
+ state = None + if eval_type in ( PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, @@ -336,6 +374,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +384,21 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # 1. State properties + properties = json.loads(utf8_deserializer.loads(infile)) + + # 2. State key + length = read_int(infile) + row = None + if length > 0: + row = pickleSer.loads(infile.read(length)) + # 3. Schema for state key + key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + properties["optionalValue"] = row + + state = GroupStateImpl(keySchema=key_schema, **properties) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -438,7 +492,7 @@ def map_batch(batch): ) # profiling is not supported for UDF - return func, None, ser, ser + return func, None, ser, ser, state def extract_key_value_indexes(grouped_arg_offsets): """ @@ -469,14 +523,19 @@ def extract_key_value_indexes(grouped_arg_offsets): idx += offsets_len return parsed - if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if eval_type in ( + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ): # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 - # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, f = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0, state=state + ) parsed_offsets = extract_key_value_indexes(arg_offsets) # Create function like this: @@ -519,11 +578,12 @@ def func(_, it): return map(mapper, it) # profiling is not supported for UDF - return func, None, ser, ser + return func, None, ser, ser, state def main(infile, outfile): faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) + state = None try: if faulthandler_log_path: faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) @@ -585,6 +645,7 @@ def main(infile, outfile): ) # initialize global state + state = None taskContext = None if isBarrier: taskContext = BarrierTaskContext._getOrCreate() @@ -667,7 +728,9 @@ def main(infile, outfile): if eval_type == PythonEvalType.NON_UDF: func, profiler, deserializer, serializer = read_command(pickleSer, infile) else: - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) + func, profiler, deserializer, serializer, state = read_udfs( + pickleSer, infile, eval_type + ) init_time = time.time() @@ -722,6 +785,18 @@ def process(): # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + + # Send GroupState back to JVM if exists. + if state is not None: + # 1. Send JSON-serialized GroupState + write_with_length(state.json().encode("utf-8"), outfile) + + # 2. 
Send pickled Row. + if state._value is None: + write_int(0, outfile) + else: + write_with_length(pickleSer.dumps(state._key_schema.toInternal(state._value)), outfile) + write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index a814525f870c9..479a097713f51 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -32,6 +32,9 @@ @Experimental @Evolving public class GroupStateTimeout { + // NOTE: if you're adding new type of timeout, you should also fix the places below: + // - Scala: org.apache.spark.sql.api.python.PythonSQLUtils.getGroupStateTimeoutFromString + // - Python: pyspark.sql.streaming.state.GroupStateTimeout /** * Timeout based on processing time. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b3508342..67d072bc36824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -98,6 +100,43 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +/** + * Similar with [[FlatMapGroupsWithState]]. Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with an pandas DataFrame representation and the + * grouping key (tuple). 
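# Illustration (not part of the patch): how the "grouping key (tuple)" mentioned above is built
# on the Python worker. wrap_grouped_map_pandas_udf_with_state (earlier in this diff) takes the
# first element of each grouping column's series, so for groupBy("a", "b") the user function
# sees a 2-tuple. The series below are hypothetical stand-ins for what Arrow hands the worker
# for a single group.
import pandas as pd

key_series = [pd.Series(["hello", "hello"]), pd.Series([1, 1])]  # grouping columns a, b
key = tuple(s[0] for s in key_series)
assert key == ("hello", 1)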
+ * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + isMapGroupsWithState: Boolean = false, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + if (isMapGroupsWithState) { + assert(outputMode == OutputMode.Update) + } + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee32521871..6c7b14b2334cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,36 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def applyInPandasWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = FlatMapGroupsInPandasWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index a3ba863623398..c8594b9c49b89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -22,14 +22,15 @@ import java.net.Socket import java.nio.channels.Channels import java.util.Locale -import net.razorvine.pickle.Pickler +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} private[sql] object PythonSQLUtils extends Logging { - private lazy val internalRowPickler = { + private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = { EvaluatePython.registerPicklers() - new Pickler(true, false) + val pickler = new Pickler(true, false) + val ret = try { + f(pickler) + } finally { + pickler.close() + } + ret + } + + private def withInternalRowUnpickler(f: Unpickler => Any): Any = { + EvaluatePython.registerPicklers() + val unpickler = new Unpickler + val ret = try { + f(unpickler) + } finally { + unpickler.close() + } + ret } def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText) @@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging { def toPyRow(row: Row): Array[Byte] = { assert(row.isInstanceOf[GenericRowWithSchema]) - internalRowPickler.dumps(EvaluatePython.toJava( - CatalystTypeConverters.convertToCatalyst(row), row.schema)) + withInternalRowPickler(_.dumps(EvaluatePython.toJava( + CatalystTypeConverters.convertToCatalyst(row), row.schema))) + } + + def toJVMRow( + arr: Array[Byte], + returnType: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]): Row = { + val fromJava = EvaluatePython.makeFromJava(returnType) + val internalRow = + fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow] + deserializer(internalRow) } def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7bea4..7ec47f469adde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[FlatMapGroupsInPandasWithState]] 
logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FlatMapGroupsInPandasWithState( + func, groupAttr, outputAttr, stateType, outputMode, _, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = python.FlatMapGroupsInPandasWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. */ @@ -793,6 +812,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-XXXXX): Implement batch support for applyInPandasWithState + throw new UnsupportedOperationException("applyInPandasWithState is unsupported.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala new file mode 100644 index 0000000000000..95e993f0685b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunnerWithState.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.nio.charset.StandardCharsets + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]]. 
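# Illustration (not part of the patch): per the SparkStrategies change above, only the streaming
# path is planned; the batch path throws "applyInPandasWithState is unsupported." when an action
# triggers planning. A minimal sketch, assuming an active SparkSession named `spark`.
import pandas as pd

from pyspark.sql.streaming.state import GroupStateTimeout


def count_fn(key, pdf, state):  # same shape as the streaming example earlier in this diff
    state.update((len(pdf),))
    return pd.DataFrame({"value": [key[0]], "count": [len(pdf)]})


batch_df = spark.createDataFrame([("hello",), ("hello",)], ["value"])
try:
    batch_df.groupBy("value").applyInPandasWithState(
        count_fn, "value STRING, count BIGINT", "count BIGINT",
        "Update", GroupStateTimeout.NoTimeout,
    ).collect()
except Exception as e:
    print(e)  # the JVM UnsupportedOperationException surfaces through Py4J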
+ */ +class ArrowPythonRunnerWithState( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val workerConf: Map[String, String], + oldState: GroupStateImpl[Row], + deserializer: ExpressionEncoder.Deserializer[Row], + stateType: StructType) + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) + with PythonArrowInput + with PythonArrowOutput { + + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + var newGroupState: GroupStateImpl[Row] = _ + + protected override def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + + // 1. Send JSON-serialized GroupState + PythonRDD.writeUTF(oldState.json(), stream) + + // 2. Send pickled Row from the GroupState + val rowInState = oldState.getOption.map(PythonSQLUtils.toPyRow).getOrElse(Array.empty) + stream.writeInt(rowInState.length) + if (rowInState.length > 0) { + stream.write(rowInState) + } + + // 3. Send the state type to serialize the output state back from Python. + PythonRDD.writeUTF(stateType.json, stream) + } + + protected override def handleMetadataAfterExec(stream: DataInputStream): Unit = { + super.handleMetadataAfterExec(stream) + + implicit val formats = org.json4s.DefaultFormats + + // 1. Receive JSON-serialized GroupState + val jsonStr = new Array[Byte](stream.readInt()) + stream.readFully(jsonStr) + val properties = parse(new String(jsonStr, StandardCharsets.UTF_8)) + + // 2. Receive and deserialized pickled Row to JVM Row. + val length = stream.readInt() + val maybeRow = if (length > 0) { + val pickledRow = new Array[Byte](length) + stream.readFully(pickledRow) + Some(PythonSQLUtils.toJVMRow(pickledRow, stateType, deserializer)) + } else { + None + } + + // 3. Create a group state. 
+ newGroupState = GroupStateImpl.fromJson(maybeRow, properties) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b54662..b39787b12a484 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f9985..f0e815e966e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 0000000000000..4ccd44a0b4297 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
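# Illustration (not part of the patch): the JSON handshake shown in handleMetadataBeforeExec /
# handleMetadataAfterExec, seen from the Python side. GroupStateImpl.json() carries only the
# scalar properties; the state value row travels separately as a pickled payload. This reuses
# the `state` object from the earlier GroupStateImpl sketch (an assumption of this example).
import json

props = json.loads(state.json())
assert set(props) == {
    "optionalValue", "batchProcessingTimeMs", "eventTimeWatermarkMs", "timeoutConf",
    "hasTimedOut", "watermarkPresent", "defined", "updated", "removed", "timeoutTimestamp",
}
assert props["optionalValue"] is None  # the value Row is never embedded in the JSON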
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, resolveArgOffsets} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-XXXXX): Add the support of initial state. 
+ override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + private val stateDeserializer = + stateEncoder.asInstanceOf[ExpressionEncoder[Row]].createDeserializer() + + def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + + val runner = new ArrowPythonRunnerWithState( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, + groupedState, + stateDeserializer, + stateType) + + val inputIter = if (hasTimedOut) { + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + Iterator.single(Iterator.single(joinedKeyRow)) + } else { + Iterator.single(valueRowIter.map(unsafeProj)) + } + + val ret = executePython(inputIter, output, runner).toArray + numOutputRows += ret.length + val newGroupState: GroupStateImpl[Row] = runner.newGroupState + assert(newGroupState != null) + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, stateData.keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val 
updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, the updated state value will + // be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](ret.iterator, onIteratorCompletion) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4ef..078876664062d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 3ff539b9ef32b..ee9449151758b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expressi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} @@ -33,59 +34,35 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} /** * Physical operator for executing `FlatMapGroupsWithState` - * - * @param func function called on each group - * @param keyDeserializer used to extract the key object for each group. - * @param valueDeserializer used to extract the items in the iterator from an input row. 
- * @param initialStateDeserializer used to extract the state object from the initialState dataset - * @param groupingAttributes used to group the data - * @param dataAttributes used to read the data - * @param outputObjAttr Defines the output object - * @param stateEncoder used to serialize/deserialize state before calling `func` - * @param outputMode the output mode of `func` - * @param timeoutConf used to timeout groups that have not received data in a while - * @param batchTimestampMs processing timestamp of the current batch. - * @param eventTimeWatermark event time watermark for the current batch - * @param initialState the user specified initial state - * @param hasInitialState indicates whether the initial state is provided or not - * @param child the physical plan for the underlying data */ -case class FlatMapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - initialStateDeserializer: Expression, - groupingAttributes: Seq[Attribute], - initialStateGroupAttrs: Seq[Attribute], - dataAttributes: Seq[Attribute], - initialStateDataAttrs: Seq[Attribute], - outputObjAttr: Attribute, - stateInfo: Option[StatefulOperatorStateInfo], - stateEncoder: ExpressionEncoder[Any], - stateFormatVersion: Int, - outputMode: OutputMode, - timeoutConf: GroupStateTimeout, - batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], - initialState: SparkPlan, - hasInitialState: Boolean, - child: SparkPlan - ) extends BinaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - - import FlatMapGroupsWithStateExecHelper._ +trait FlatMapGroupsWithStateExecBase + extends StateStoreWriter with WatermarkSupport { import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ - override def left: SparkPlan = child + protected val groupingAttributes: Seq[Attribute] - override def right: SparkPlan = initialState + protected val initialStateDeserializer: Expression + protected val initialStateGroupAttrs: Seq[Attribute] + protected val initialStateDataAttrs: Seq[Attribute] + protected val initialState: SparkPlan + protected val hasInitialState: Boolean + + val stateInfo: Option[StatefulOperatorStateInfo] + protected val stateEncoder: ExpressionEncoder[Any] + protected val stateFormatVersion: Int + protected val outputMode: OutputMode + protected val timeoutConf: GroupStateTimeout + protected val batchTimestampMs: Option[Long] + val eventTimeWatermark: Option[Long] private val isTimeoutEnabled = timeoutConf != NoTimeout - private val watermarkPresent = child.output.exists { + protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true case _ => false } - private[sql] val stateManager = + lazy val stateManager: StateManager = createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion) /** @@ -240,7 +217,7 @@ case class FlatMapGroupsWithStateExec( stateManager.stateSchema, numColsPrefixKey = 0, stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value) - val processor = new InputProcessor(store) + val processor = createInputProcessor(store) processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator)) } } else { @@ -252,21 +229,15 @@ case class FlatMapGroupsWithStateExec( session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator) ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - val processor = 
new InputProcessor(store) + val processor = createInputProcessor(store) processDataWithPartition(singleIterator, store, processor) } } } - /** Helper class to update the state store */ - class InputProcessor(store: StateStore) { + def createInputProcessor(store: StateStore): InputProcessor - // Converters for translating input keys, values, output data between rows and Java objects - private val getKeyObj = - ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - private val getValueObj = - ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) + abstract class InputProcessor(store: StateStore) { private val getStateObj = if (hasInitialState) { Some(ObjectOperator.deserializeRowToObject(initialStateDeserializer, initialStateDataAttrs)) } else { @@ -274,9 +245,9 @@ case class FlatMapGroupsWithStateExec( } // Metrics - private val numUpdatedStateRows = longMetric("numUpdatedStateRows") - private val numOutputRows = longMetric("numOutputRows") - private val numRemovedStateRows = longMetric("numRemovedStateRows") + protected val numUpdatedStateRows: SQLMetric = longMetric("numUpdatedStateRows") + protected val numOutputRows: SQLMetric = longMetric("numOutputRows") + protected val numRemovedStateRows: SQLMetric = longMetric("numRemovedStateRows") /** * For every group, get the key, values and corresponding state and call the function, @@ -300,8 +271,7 @@ case class FlatMapGroupsWithStateExec( */ def processNewDataWithInitialState( childDataIter: Iterator[InternalRow], - initStateIter: Iterator[InternalRow] - ): Iterator[InternalRow] = { + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty @@ -313,7 +283,7 @@ case class FlatMapGroupsWithStateExec( // Create a CoGroupedIterator that will group the two iterators together for every key group. new CoGroupedIterator( - groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { + groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap { case (keyRow, valueRowIter, initialStateRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] var foundInitialStateForKey = false @@ -328,8 +298,8 @@ case class FlatMapGroupsWithStateExec( // We apply the values for the key after applying the initial state. callFunctionAndUpdateState( stateManager.getState(store, keyUnsafeRow), - valueRowIter, - hasTimedOut = false + valueRowIter, + hasTimedOut = false ) } } @@ -362,7 +332,74 @@ case class FlatMapGroupsWithStateExec( * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty * @param hasTimedOut Whether this function is being called for a key timeout */ - private def callFunctionAndUpdateState( + protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] + } +} + +/** + * Physical operator for executing `FlatMapGroupsWithState` + * + * @param func function called on each group + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. 
+ * @param initialStateDeserializer used to extract the state object from the initialState dataset + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr Defines the output object + * @param stateEncoder used to serialize/deserialize state before calling `func` + * @param outputMode the output mode of `func` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermark event time watermark for the current batch + * @param initialState the user specified initial state + * @param hasInitialState indicates whether the initial state is provided or not + * @param child the physical plan for the underlying data + */ +case class FlatMapGroupsWithStateExec( + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + initialStateDeserializer: Expression, + groupingAttributes: Seq[Attribute], + initialStateGroupAttrs: Seq[Attribute], + dataAttributes: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + outputObjAttr: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + stateEncoder: ExpressionEncoder[Any], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + initialState: SparkPlan, + hasInitialState: Boolean, + child: SparkPlan) + extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec { + import GroupStateImpl._ + import FlatMapGroupsWithStateExecHelper._ + + override def left: SparkPlan = child + + override def right: SparkPlan = initialState + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = + copy(child = newLeft, initialState = newRight) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store) { + // Converters for translating input keys, values, output data between rows and Java objects + private val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + private val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) + + override protected def callFunctionAndUpdateState( stateData: StateData, valueRowIter: Iterator[InternalRow], hasTimedOut: Boolean): Iterator[InternalRow] = { @@ -405,10 +442,6 @@ case class FlatMapGroupsWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } } - - override protected def withNewChildrenInternal( - newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = - copy(child = newLeft, initialState = newRight) } object FlatMapGroupsWithStateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index b4f37125f4fa9..861ceabaf7f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.sql.Date import java.util.concurrent.TimeUnit +import org.json4s._ +import 
org.json4s.jackson.JsonMethods._ + import org.apache.spark.api.java.Optional import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.util.IntervalUtils @@ -27,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.GroupStateImpl._ import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState} import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe. @@ -39,7 +43,10 @@ import org.apache.spark.unsafe.types.UTF8String * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class GroupStateImpl[S] private( +private[sql] class GroupStateImpl[S] private[sql]( + // NOTE: if you're adding new properties here, fix: + // - `json` and `fromJson` methods of this class in Scala + // - pyspark.sql.streaming.state.GroupStateImpl in Python optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, @@ -173,6 +180,22 @@ private[sql] class GroupStateImpl[S] private( throw QueryExecutionErrors.cannotSetTimeoutTimestampError() } } + + private[sql] def json(): String = compact(render(new JObject( + // Constructor + "optionalValue" -> JNull :: // Note that optionalValue will be manually serialized. + "batchProcessingTimeMs" -> JLong(batchProcessingTimeMs) :: + "eventTimeWatermarkMs" -> JLong(eventTimeWatermarkMs) :: + "timeoutConf" -> JString(Utils.stripDollars(Utils.getSimpleName(timeoutConf.getClass))) :: + "hasTimedOut" -> JBool(hasTimedOut) :: + "watermarkPresent" -> JBool(watermarkPresent) :: + + // Internal state + "defined" -> JBool(defined) :: + "updated" -> JBool(updated) :: + "removed" -> JBool(removed) :: + "timeoutTimestamp" -> JLong(timeoutTimestamp) :: Nil + ))) } @@ -214,4 +237,35 @@ private[sql] object GroupStateImpl { hasTimedOut = false, watermarkPresent) } + + def groupStateTimeoutFromString(clazz: String): GroupStateTimeout = clazz match { + case "ProcessingTimeTimeout" => GroupStateTimeout.ProcessingTimeTimeout + case "EventTimeTimeout" => GroupStateTimeout.EventTimeTimeout + case "NoTimeout" => GroupStateTimeout.NoTimeout + case _ => throw new IllegalStateException("Invalid string for GroupStateTimeout: " + clazz) + } + + def fromJson[S](key: Option[S], json: JValue): GroupStateImpl[S] = { + implicit val formats = org.json4s.DefaultFormats + + val hmap = json.extract[Map[String, Any]] + + // Constructor + val newGroupState = new GroupStateImpl[S]( + key, + hmap("batchProcessingTimeMs").asInstanceOf[Number].longValue(), + hmap("eventTimeWatermarkMs").asInstanceOf[Number].longValue(), + groupStateTimeoutFromString(hmap("timeoutConf").asInstanceOf[String]), + hmap("hasTimedOut").asInstanceOf[Boolean], + hmap("watermarkPresent").asInstanceOf[Boolean]) + + // Internal state + newGroupState.defined = hmap("defined").asInstanceOf[Boolean] + newGroupState.updated = hmap("updated").asInstanceOf[Boolean] + newGroupState.removed = hmap("removed").asInstanceOf[Boolean] + newGroupState.timeoutTimestamp = + hmap("timeoutTimestamp").asInstanceOf[Number].longValue() + + newGroupState + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e973b..f386282a0b3e6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,6 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,13 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: FlatMapGroupsInPandasWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7bcc..022fd1239ce4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 827cfcf32fead..3c41f6b47b5ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType} /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF, @@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) { + private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -213,7 +214,7 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } - private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) { + private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -235,6 +236,34 @@ object IntegratedUDFTestUtils extends SQLHelper { throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") } + private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = { + if (shouldTestPandasUDFs) { + var binaryPandasFunc: Array[Byte] = null + withTempPath { codePath => + Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8)) + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + "from pyspark.serializers import CloudPickleSerializer; " + + s"f = open('$path', 'wb');" + + s"exec(open('$codePath', 'r').read());" + + "f.write(CloudPickleSerializer().dumps((" + + "func, tpe)))"), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } else { + throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.") + } + } + + // Make sure this map stays mutable - this map gets updated later in Python runners. 
private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") @@ -251,11 +280,9 @@ object IntegratedUDFTestUtils extends SQLHelper { lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable - lazy val shouldTestScalarPandasUDFs: Boolean = + lazy val shouldTestPandasUDFs: Boolean = isPythonAvailable && isPandasAvailable && isPyArrowAvailable - lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs - /** * A base trait for various UDFs defined in this object. */ @@ -420,6 +447,41 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String = "Grouped Aggregate Pandas UDF" } + /** + * Arbitrary stateful processing in Python is used for + * `DataFrame.groupBy.applyInPandasWithState`. It requires `pythonScript` to + * define `func` (Python function) and `tpe` (the output `StructType` of `func`). + * + * Virtually equivalent to: + * + * {{{ + * # exec defines 'func' and 'tpe' (the output struct type of 'func') + * exec(pythonScript) + * + * # ... are filled when this UDF is invoked, see also 'FlatMapGroupsInPandasWithStateSuite'. + * df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...) + * }}} + */ + case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF { + private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction( + name = name, + func = SimplePythonFunction( + command = createPandasGroupedMapFuncWithState(pythonScript), + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = NullType, // This is not respected. + pythonEvalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Grouped Map Pandas UDF with State" + } + /** + * A Scala UDF that takes one column, casts into string, executes the + * Scala native function, and casts back to the type of input column. 
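For readability, here is a short sketch (not part of the patch) of what the `python -c` one-liner in `createPandasGroupedMapFuncWithState` effectively runs; the file paths are hypothetical stand-ins for the two temp files created by `withTempPath`:

from pyspark.serializers import CloudPickleSerializer

code_path = "/tmp/udf_script.py"    # temp file holding `pythonScript` (hypothetical path)
out_path = "/tmp/pickled_func.bin"  # temp file later read back into `binaryPandasFunc` (hypothetical path)

# Execute the user script in a scratch namespace; it must define `func` and `tpe`
# (the output StructType of `func`).
ns = {}
exec(open(code_path, "r").read(), ns)

# Cloudpickle the (func, tpe) pair, mirroring the command string above, so the bytes
# can be unpickled again on the Python worker side.
with open(out_path, "wb") as f:
    f.write(CloudPickleSerializer().dumps((ns["func"], ns["tpe"])))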
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index bd48d17303952..e55f78aefd308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -246,7 +246,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper /* Do nothing */ } case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs => ignore(s"${testCase.name} is skipped because pyspark," + s"pandas and/or pyarrow were not available in [$pythonExec].") { /* Do nothing */ @@ -435,7 +435,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs => s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case udfTestCase: UDFTest - if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs => + if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 85aa7221b0ee6..37900db9360ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -125,7 +125,7 @@ class QueryCompilationErrorsSuite test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), @@ -176,7 +176,7 @@ class QueryCompilationErrorsSuite test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") { import IntegratedUDFTestUtils._ - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq( (536361, "85123A", 2, 17850), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 4ad7f90105373..42e4b1accde72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") { - assume(shouldTestGroupedAggPandasUDFs) + assume(shouldTestPandasUDFs) val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b") val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala new file mode 100644 index 0000000000000..03d3fd6dcff1e --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.IntegratedUDFTestUtils._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.timestamp_seconds +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types._ + +class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest { + + import testImplicits._ + + test("applyInPandasWithState - streaming") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | if count == 3: + | state.remove() + | return pd.DataFrame() + | else: + | state.update((count,)) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val inputDataDS = inputData.toDS() + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "NoTimeout") + + testStream(result, Update)( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), 
// should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("applyInPandasWithState - streaming + aggregation") { + assume(shouldTestPandasUDFs) + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | if count == 3: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | state.update((count,)) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Append", + "NoTimeout") + .groupBy("key") + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckNewAnswer(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckNewAnswer(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckNewAnswer(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("applyInPandasWithState - streaming with processing time timeout") { + assume(shouldTestPandasUDFs) + + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. 
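+    // Note: the processing-time timeout set below only fires when a later micro-batch runs,
+    // so this test drives batch execution explicitly via StreamManualClock / AdvanceManualClock.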
+ val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | try: + | state.getCurrentWatermarkMs() + | assert False + | except RuntimeError as e: + | assert "watermark" in str(e) + | + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | state.update((count,)) + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val inputDataDS = inputData.toDS + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDS + .groupBy("value") + .applyInPandasWithState( + pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckNewAnswer(("a", "-1"), ("b", "2")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckNewAnswer(("b", "-1"), ("c", "1")), + assertNumStateRows( + total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1))), + + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows( + total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed = Some(Seq(1))) + ) + } + + test("applyInPandasWithState - streaming w/ event time timeout + watermark") { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + val pythonScript = + """ + |import calendar + |import os + |import datetime + |import pandas as pd + |from pyspark.sql.types import StructType, StringType, StructField, IntegerType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("maxEventTimeSec", IntegerType())]) + | + |def func(key, pdf, state): + | assert state.getCurrentProcessingTimeMs() >= 0 + | assert state.getCurrentWatermarkMs() >= -1 + | + | timeout_delay_sec = 5 + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]}) + | else: + | m = state.getOption + | if m is None: + | m = 0 + | else: + | m = m[0] + | + | pser = pdf.eventTime.apply( + | lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond))) + | max_event_time_sec = int(max(pser.max(), m)) + | timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec + | state.update((max_event_time_sec,)) + | state.setTimeoutTimestamp(timeout_timestamp_sec * 1000) + | return pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [max_event_time_sec]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val inputData = MemoryStream[(String, Int)] + val inputDataDF = + inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("maxEventTimeSec", IntegerType))) + val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType))) + val result = + inputDataDF + .withWatermark("eventTime", "10 seconds") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "EventTimeTimeout") + + testStream(result, Update)( + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 + ) + } + } + + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + assume(shouldTestPandasUDFs) + + // timestamp_seconds assumes the base timezone is UTC. However, the provided function + // localizes it. 
Therefore, this test assumes the timezone is in UTC + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // String, (String, Long), RunningCount(Long) + val pythonScript = + """ + |import pandas as pd + |from pyspark.sql.types import StructType, StructField, StringType + | + |tpe = StructType([ + | StructField("key", StringType()), + | StructField("countAsString", StringType())]) + | + |def func(key, pdf, state): + | if state.hasTimedOut: + | state.remove() + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(-1)]}) + | else: + | count = state.getOption + | if count is None: + | count = 0 + | else: + | count = count[0] + | count += len(pdf) + | state.update((count,)) + | state.setTimeoutDuration(10000) + | return pd.DataFrame({'key': [key[0]], 'countAsString': [str(count)]}) + |""".stripMargin + val pythonFunc = TestGroupedMapPandasUDFWithState( + name = "pandas_grouped_map_with_state", pythonScript = pythonScript) + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val inputDataDF = inputData + .toDF.toDF("key", "time") + .selectExpr("key", "timestamp_seconds(time) as timestamp") + val outputStructType = StructType( + Seq( + StructField("key", StringType), + StructField("countAsString", StringType))) + val stateStructType = StructType(Seq(StructField("count", LongType))) + val result = + inputDataDF + .withWatermark("timestamp", "10 second") + .groupBy("key") + .applyInPandasWithState( + pythonFunc(inputDataDF("key"), inputDataDF("timestamp")).expr.asInstanceOf[PythonUDF], + outputStructType, + stateStructType, + "Update", + "ProcessingTimeTimeout") + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")) + ) + } + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d34ceea8dd47..b7c9aa4178090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1733,7 +1733,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val inputProcessor = mapGroupsSparkPlan.createInputProcessor(store) val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 26c201d5921ed..fc6b51dce790b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -279,7 +279,7 @@ class ContinuousSuite extends ContinuousSuiteBase { Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), 
TestScalarPandasUDF("udf")).foreach { udf => test(s"continuous mode with various UDFs - ${udf.prettyName}") { assume( - shouldTestScalarPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || + shouldTestPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] || shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] || udf.isInstanceOf[TestScalaUDF])