[SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark #37893
python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
-from typing import List, Union, TYPE_CHECKING
+from typing import List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType
+from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
    from pyspark.sql.pandas._typing import (
        GroupedMapPandasUserDefinedFunction,
        PandasGroupedMapFunction,
+        PandasGroupedMapFunctionWithState,
        PandasCogroupedMapFunction,
    )
    from pyspark.sql.group import GroupedData

@@ -216,6 +218,125 @@ def applyInPandas(
        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.session)

    def applyInPandasWithState(
        self,
        func: "PandasGroupedMapFunctionWithState",
        outputStructType: Union[StructType, str],
        stateStructType: Union[StructType, str],
        outputMode: str,
        timeoutConf: str,
    ) -> DataFrame:
        """
        Applies the given function to each group of data, while maintaining a user-defined
        per-group state. The result Dataset will represent the flattened records returned by
        the function.

        For a streaming Dataset, the function will be invoked first for all input groups and
        then for all timed-out states where the input data is set to be empty. Updates to each
        group's state will be saved across invocations.

        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a
        tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be
        passed as :class:`pyspark.sql.streaming.state.GroupState`.

        For each group, all columns are passed together as `pandas.DataFrame` to the
        user-function,
Contributor: On 'all columns are passed together as `pandas.DataFrame`': consider "Each group is passed as one or more `pandas.DataFrame` of records with all columns packed into the DataFrame."

Contributor (Author): Same here; this follows the existing method doc in applyInPandas. I'm OK with changing it, though, as I agree it's not mandatory to call out that all columns will be passed. Neither the user function nor the public API specifies columns; all columns are implicitly expected. Probably worth discussing a bit more and changing it in both functions together? cc @HyukjinKwon
        and the `pandas.DataFrame` objects returned across all invocations are combined into
        a :class:`DataFrame`. Note that the user function should not make assumptions about
        the number of elements in the iterator. To process all data, the user function needs
        to iterate over all the elements. On the other hand, the user function is not strictly
        required to iterate through all the elements if it intends to read only part of the
        data.
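For concreteness, here is a minimal sketch of a function matching this contract; the names are illustrative only, not part of the PR. The key arrives as a tuple, the group's rows arrive as an iterator of `pandas.DataFrame` chunks carrying all columns, and the function yields zero or more `pandas.DataFrame`s:

```python
import pandas as pd
from typing import Iterator, Tuple
from pyspark.sql.streaming.state import GroupState

# Hypothetical pass-through handler; assumes outputStructType matches the input columns.
def passthrough_fn(
    key: Tuple, pdf_iter: Iterator[pd.DataFrame], state: GroupState
) -> Iterator[pd.DataFrame]:
    for pdf in pdf_iter:  # each chunk carries every column of the group
        yield pdf         # emit the records unchanged
```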
        The `outputStructType` should be a :class:`StructType` describing the schema of all
        elements in the returned value, `pandas.DataFrame`. The column labels of all elements
        in the returned `pandas.DataFrame` must either match the field names in the defined
        schema if specified as strings, or match the field data types by position if not
        strings, e.g. integer indices.

        The `stateStructType` should be a :class:`StructType` describing the schema of the
        user-defined state. The value of the state will be presented as a tuple, and updates
        should be performed with a tuple as well. The corresponding Python types for
        :class:`DataType` are supported. Please refer to the page
        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (Python tab).
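To illustrate the tuple semantics described above (a sketch; the `count long` state schema and the handler name are assumptions): reads come back as a tuple matching `stateStructType`, and updates must be passed as a tuple too.

```python
import pandas as pd

# Hypothetical handler whose state was declared as "count long".
def running_count(key, pdf_iter, state):
    (count,) = state.get if state.exists else (0,)  # state is read back as a tuple
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))  # the update must be a tuple matching stateStructType
    yield pd.DataFrame({"id": [key[0]], "count": [count]})
```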
        The size of each DataFrame in both the input and output can be arbitrary. The number
        of DataFrames in both the input and output can also be arbitrary.
Member (on lines +262 to +263): I think we can extract some notes from the description into the `Notes` section.
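As a sketch of the arbitrary sizing noted above (the splitting scheme is made up, and it assumes the output schema matches the input columns), a handler may consume any number of input chunks and emit any number of output frames, of any size:

```python
# Hypothetical handler: re-chunk each incoming DataFrame into two halves.
def emit_in_halves(key, pdf_iter, state):
    for pdf in pdf_iter:
        mid = len(pdf) // 2
        yield pdf.iloc[:mid]   # output frames may have any size,
        yield pdf.iloc[mid:]   # and there may be any number of them
```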

        .. versionadded:: 3.4.0

        Parameters
        ----------
        func : function
            a Python native function to be called on every group. It should take parameters
            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
            Note that the type of the key is tuple and the type of the state is
            :class:`pyspark.sql.streaming.state.GroupState`.
        outputStructType : :class:`pyspark.sql.types.DataType` or str
            the type of the output records. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
Contributor: Can you provide an example here of the string?

Contributor (Author): In the doc or here? None of the other PySpark method docs have an example of this string. Maybe we could have examples like other APIs do and provide a DDL-formatted type string there to compensate.
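For reference, a DDL-formatted type string and an equivalent :class:`StructType` object; either form is accepted for `outputStructType` and `stateStructType` (the field names below are taken from the docstring example):

```python
from pyspark.sql.types import StructType, StructField, LongType, StringType

# DDL-formatted type string...
ddl_schema = "id long, countAsString string"

# ...and the equivalent StructType object.
object_schema = StructType([
    StructField("id", LongType()),
    StructField("countAsString", StringType()),
])
```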
        stateStructType : :class:`pyspark.sql.types.DataType` or str
            the type of the user-defined state. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
Contributor: Same here; can you provide an example of the string?

Contributor (Author): Same.
        outputMode : str
            the output mode of the function.
        timeoutConf : str
            timeout configuration for groups that do not receive data for a while. Valid
            values are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.

        Examples
        --------
        >>> import pandas as pd  # doctest: +SKIP
        >>> from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
        >>> def count_fn(key, pdf_iter, state):
        ...     assert isinstance(state, GroupState)
        ...     total_len = 0
        ...     for pdf in pdf_iter:
        ...         total_len += len(pdf)
        ...     state.update((total_len,))
        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
        >>> df.groupby("id").applyInPandasWithState(
        ...     count_fn, outputStructType="id long, countAsString string",
        ...     stateStructType="len long", outputMode="Update",
        ...     timeoutConf=GroupStateTimeout.NoTimeout)  # doctest: +SKIP

        Notes
        -----
        This function requires a full shuffle.

        This API is experimental.
        """

        from pyspark.sql import GroupedData
        from pyspark.sql.functions import pandas_udf

        assert isinstance(self, GroupedData)
        assert timeoutConf in [
            GroupStateTimeout.NoTimeout,
            GroupStateTimeout.ProcessingTimeTimeout,
            GroupStateTimeout.EventTimeTimeout,
        ]

        if isinstance(outputStructType, str):
            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
        if isinstance(stateStructType, str):
            stateStructType = cast(StructType, _parse_datatype_string(stateStructType))

        udf = pandas_udf(
            func,  # type: ignore[call-overload]
            returnType=outputStructType,
            functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
        )
        df = self._df
        udf_column = udf(*[df[col] for col in df.columns])
        jdf = self._jgd.applyInPandasWithState(
            udf_column._jc.expr(),
            self.session._jsparkSession.parseDataType(outputStructType.json()),
            self.session._jsparkSession.parseDataType(stateStructType.json()),
            outputMode,
            timeoutConf,
        )
        return DataFrame(jdf, self.session)
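Putting it together, a sketch of how the new API might be wired into a streaming query, including the timed-out-state path described in the docstring. The rate source, the derived `id` column, the console sink, and the 30-second timeout are all assumptions for illustration, not part of this PR:

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.streaming.state import GroupStateTimeout

spark = SparkSession.builder.getOrCreate()

def count_with_timeout(key, pdf_iter, state):
    if state.hasTimedOut:
        # Timed-out invocation: the input iterator is empty; flush and clear the state.
        (count,) = state.get
        state.remove()
        yield pd.DataFrame({"id": [key[0]], "countAsString": [str(count)]})
    else:
        count = state.get[0] if state.exists else 0
        for pdf in pdf_iter:
            count += len(pdf)
        state.update((count,))
        state.setTimeoutDuration(30000)  # expire the group after 30s without data
        yield pd.DataFrame({"id": [key[0]], "countAsString": [str(count)]})

df = spark.readStream.format("rate").load().selectExpr("value % 10 AS id")
query = (
    df.groupBy("id")
    .applyInPandasWithState(
        count_with_timeout,
        outputStructType="id long, countAsString string",
        stateStructType="count long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )
    .writeStream.outputMode("update")
    .format("console")
    .start()
)
```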

    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
        """
        Cogroups this group with another group so that we can run cogrouped operations.