Closed

Changes from all commits (38 commits)
444f9a4
[SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
HeartSaVioR Sep 15, 2022
79ba311
meta-commit to credit properly on co-authorship
HyukjinKwon Sep 15, 2022
0000994
replace SPARK-XXXXX with real JIRA ticket
HeartSaVioR Sep 15, 2022
caa4924
fix
HeartSaVioR Sep 16, 2022
76bb404
Add missing piece
HeartSaVioR Sep 16, 2022
a2f25e3
unused import
HeartSaVioR Sep 16, 2022
bb3a80a
fix for Scala 2.13
HeartSaVioR Sep 16, 2022
4d3c7e9
reformat as suggested by linter
HeartSaVioR Sep 16, 2022
b93e488
Update sql/core/src/main/scala/org/apache/spark/sql/execution/python/…
HeartSaVioR Sep 19, 2022
5d23a6d
Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/anal…
HeartSaVioR Sep 19, 2022
e757be0
1st reflection of feedbacks
HeartSaVioR Sep 19, 2022
43929ac
reflect suggestion
HeartSaVioR Sep 19, 2022
e725995
WIP still documenting...
HeartSaVioR Sep 19, 2022
9c6bf60
Rename GroupStateImpl to GroupState in PySpark (No additional interface)
HeartSaVioR Sep 19, 2022
69bb3e8
WIP still updating the doc
HeartSaVioR Sep 19, 2022
516fa4f
Revert "Update sql/core/src/main/scala/org/apache/spark/sql/execution…
HeartSaVioR Sep 19, 2022
3ded763
fix style
HeartSaVioR Sep 20, 2022
f00486f
WIP still documenting...
HeartSaVioR Sep 20, 2022
2fb8da0
Change the condition of constructing Arrow RecordBatch to number of r…
HeartSaVioR Sep 20, 2022
1a4f158
small fix
HeartSaVioR Sep 20, 2022
c5b35a4
more doc
HeartSaVioR Sep 20, 2022
a95df28
still documenting...
HeartSaVioR Sep 20, 2022
0fee506
slight refactor
HeartSaVioR Sep 20, 2022
7051799
further documentation
HeartSaVioR Sep 20, 2022
5142941
further document...
HeartSaVioR Sep 20, 2022
95a1400
further doc
HeartSaVioR Sep 20, 2022
295bc9b
apply suggestion
HeartSaVioR Sep 20, 2022
9f52c80
update the doc
HeartSaVioR Sep 20, 2022
8d23d46
Add example
HeartSaVioR Sep 21, 2022
fc3dca2
slight addition
HeartSaVioR Sep 21, 2022
119daf3
fix style
HeartSaVioR Sep 21, 2022
4b7c667
fix bug during updates of code review
HeartSaVioR Sep 21, 2022
426f5e7
fix another bug during code review
HeartSaVioR Sep 21, 2022
83f2555
fix on pydoc
HeartSaVioR Sep 21, 2022
38eec2d
Fix a silly bug where the value of the state is removed or not initia…
HeartSaVioR Sep 21, 2022
8133dcd
fix an edge-case being figured out from newer test case
HeartSaVioR Sep 21, 2022
f100048
loosen the requirement
HeartSaVioR Sep 21, 2022
dd7a655
reflect feedbacks
HeartSaVioR Sep 21, 2022
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}

2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
@@ -105,6 +105,7 @@
PandasMapIterUDFType,
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
PandasGroupedMapUDFWithStateType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208


def portable_hash(x: Hashable) -> int:
6 changes: 6 additions & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal
from types import FunctionType

from pyspark.sql._typing import LiteralType
from pyspark.sql.streaming.state import GroupState
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
from numpy import ndarray as NDArray
@@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204]
PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
Callable[[Any, DataFrameLike], DataFrameLike],
]

PandasGroupedMapFunctionWithState = Callable[
    [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]

class PandasVariadicGroupedAggFunction(Protocol):
def __call__(self, *_: SeriesLike) -> LiteralType: ...

2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -369,6 +369,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
None,
]: # None means it should infer the type from type hints.

@@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType):
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
]:
# In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
# at `apply` instead.
125 changes: 123 additions & 2 deletions python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
from typing import List, Union, TYPE_CHECKING
from typing import List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedMapFunction,
PandasGroupedMapFunctionWithState,
PandasCogroupedMapFunction,
)
from pyspark.sql.group import GroupedData
@@ -216,6 +218,125 @@ def applyInPandas(
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.session)

def applyInPandasWithState(
    self,
    func: "PandasGroupedMapFunctionWithState",
    outputStructType: Union[StructType, str],
    stateStructType: Union[StructType, str],
    outputMode: str,
    timeoutConf: str,
) -> DataFrame:
"""
Applies the given function to each group of data, while maintaining a user-defined
per-group state. The result Dataset will represent the flattened record returned by the
function.

For a streaming Dataset, the function will be invoked first for all input groups and then
for all timed out states where the input data is set to be empty. Updates to each group's
state will be saved across invocations.
Member commented on lines +234 to +236, suggested change:
Replace:
    For a streaming Dataset, the function will be invoked first for all input groups and then
    for all timed out states where the input data is set to be empty. Updates to each group's
    state will be saved across invocations.
With:
    For a streaming :class:`DataFrame`, the function will be invoked first for all input groups
    and then for all timed out states where the input data is set to be empty. Updates to
    each group's state will be saved across invocations.


The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
:class:`pyspark.sql.streaming.state.GroupState`.

For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
Contributor:
"all columns are passed together as pandas.DataFrame ..." - this is confusing - of course all columns will be passed together. How about:

    Each group is passed as one or more pandas.DataFrame group of records with all columns packed into the DataFrame.

Contributor Author:
Same here, this follows the existing method doc in applyInPandas.

I'm OK with changing it though, as I agree it's not mandatory to call out that all columns will be passed. Neither the user function nor the public API specifies columns; all columns are implicitly expected.

Probably worth discussing a bit more and changing it in both functions together? cc. @HyukjinKwon

and the returned `pandas.DataFrame` across all invocations are combined as a
:class:`DataFrame`. Note that the user function should not make a guess of the number of
elements in the iterator. To process all data, the user function needs to iterate all
elements and process them. On the other hand, the user function is not strictly required to
iterate through all elements in the iterator if it intends to read a part of data.

The `outputStructType` should be a :class:`StructType` describing the schema of all
elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
returned `pandas.DataFrame` must either match the field names in the defined schema if
specified as strings, or match the field data types by position if not strings,
e.g. integer indices.

The `stateStructType` should be :class:`StructType` describing the schema of the
user-defined state. The value of the state will be presented as a tuple, as well as the
update should be performed with the tuple. The corresponding Python types for
:class:DataType are supported. Please refer to the page
https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
Member suggested change:
Replace:
    https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
With:
    https://spark.apache.org/docs/latest/sql-ref-datatypes.html (Python tab).
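As an illustrative aside, here is a minimal sketch (not taken from this diff) of the tuple semantics described above: with stateStructType="len long" the state value is read and written as a 1-tuple, via the exists/get/update accessors of pyspark.sql.streaming.state.GroupState. The function and column names below are made up for the example.

import pandas as pd

def running_count(key, pdf_iter, state):
    # Previous value, if any; GroupState presents it as a tuple matching "len long".
    (count,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:
        count += len(pdf)
    # Updates must likewise be tuples matching the state schema.
    state.update((count,))
    yield pd.DataFrame({"id": [key[0]], "count": [count]})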


The size of each DataFrame in both the input and output can be arbitrary. The number of
DataFrames in both the input and output can also be arbitrary.
Member commented on lines +262 to +263, suggested change:
Replace:
    The size of each DataFrame in both the input and output can be arbitrary. The number of
    DataFrames in both the input and output can also be arbitrary.
With:
    The size of each `pandas.DataFrame` in both the input and output can be arbitrary.
    The number of DataFrames in both the input and output can also be arbitrary.

Member:
I think we can extract some notes from the description to Notes section. But no biggie.


.. versionadded:: 3.4.0

Parameters
----------
func : function
a Python native function to be called on every group. It should take parameters
(key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
Note that the type of the key is tuple and the type of the state is
:class:`pyspark.sql.streaming.state.GroupState`.
outputStructType : :class:`pyspark.sql.types.DataType` or str
the type of the output records. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
Contributor:
Can you provide an example here of the string?

Contributor Author:
In the doc or here? All other PySpark method docs do not have an example of this string.

Maybe we could have examples like other APIs do and provide a DDL-formatted type string to compensate.
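As an illustrative aside, a DDL-formatted type string is the comma-separated "name type" form accepted by Spark's schema parser, so the two schemas in this minimal sketch are equivalent (only public pyspark.sql.types classes are used; the field names are just examples):

from pyspark.sql.types import StructType, StructField, LongType, StringType

# DDL-formatted type string, usable for outputStructType / stateStructType.
ddl_schema = "id long, countAsString string"

# The equivalent StructType object.
struct_schema = StructType([
    StructField("id", LongType()),
    StructField("countAsString", StringType()),
])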

stateStructType : :class:`pyspark.sql.types.DataType` or str
the type of the user-defined state. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
Contributor:
same - can you provide an example of the string?

Contributor Author:
same.

outputMode : str
the output mode of the function.
timeoutConf : str
timeout configuration for groups that do not receive data for a while. valid values
are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.

Examples
--------
>>> import pandas as pd  # doctest: +SKIP
>>> from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
>>> def count_fn(key, pdf_iter, state):
...     assert isinstance(state, GroupState)
...     total_len = 0
...     for pdf in pdf_iter:
...         total_len += len(pdf)
...     state.update((total_len,))
...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
Member suggested change (add a terminating "..." line to close the doctest block):
Replace:
    ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
With:
    ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
    ...

>>> df.groupby("id").applyInPandasWithState(
...     count_fn, outputStructType="id long, countAsString string",
...     stateStructType="len long", outputMode="Update",
...     timeoutConf=GroupStateTimeout.NoTimeout)  # doctest: +SKIP

Notes
-----
This function requires a full shuffle.

This API is experimental.
"""

from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
assert timeoutConf in [
    GroupStateTimeout.NoTimeout,
    GroupStateTimeout.ProcessingTimeTimeout,
    GroupStateTimeout.EventTimeTimeout,
]

if isinstance(outputStructType, str):
    outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
if isinstance(stateStructType, str):
    stateStructType = cast(StructType, _parse_datatype_string(stateStructType))

udf = pandas_udf(
    func,  # type: ignore[call-overload]
    returnType=outputStructType,
    functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.applyInPandasWithState(
    udf_column._jc.expr(),
    self.session._jsparkSession.parseDataType(outputStructType.json()),
    self.session._jsparkSession.parseDataType(stateStructType.json()),
    outputMode,
    timeoutConf,
)
return DataFrame(jdf, self.session)
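As a rough end-to-end usage sketch (not part of this diff; the rate source, column names, and console sink are illustrative assumptions), the new method plugs into a streaming query like the other grouped pandas APIs:

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.streaming.state import GroupStateTimeout

spark = SparkSession.builder.getOrCreate()

def count_fn(key, pdf_iter, state):
    # Running count per key, kept in the user-defined state ("len long").
    (total,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:
        total += len(pdf)
    state.update((total,))
    yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total)]})

# Illustrative source: the built-in rate source, keyed into ten groups.
events = spark.readStream.format("rate").load().withColumn("id", col("value") % 10)

query = (
    events.groupBy("id")
    .applyInPandasWithState(
        count_fn,
        outputStructType="id long, countAsString string",
        stateStructType="len long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )
    .writeStream.outputMode("update")
    .format("console")
    .start()
)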

def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
Cogroups this group with another group so that we can run cogrouped operations.