Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jingz-db committed Nov 18, 2024
1 parent b7e6f59 commit 0c5ab3f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
27 changes: 14 additions & 13 deletions python/pyspark/sql/pandas/group_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#
import itertools
import sys
from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast
from typing import Any, Iterator, List, Optional, Union, Tuple, TYPE_CHECKING, cast
import warnings

from pyspark.errors import PySparkTypeError
Expand Down Expand Up @@ -502,6 +502,17 @@ def transformWithStateInPandas(
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

def get_timestamps(
    statefulProcessorApiClient: StatefulProcessorApiClient,
) -> Tuple[int, int]:
    """Return ``(batch_timestamp, watermark_timestamp)`` for the current batch.

    When the stream's ``timeMode`` is ``"none"``, timestamps are not tracked,
    so both values default to ``-1``; otherwise they are fetched from the
    stateful processor API client.
    """
    if timeMode == "none":
        return -1, -1
    return (
        statefulProcessorApiClient.get_batch_timestamp(),
        statefulProcessorApiClient.get_watermark_timestamp(),
    )

def handle_data_with_timers(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
Expand Down Expand Up @@ -568,12 +579,7 @@ def transformWithStateUDF(
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])

if timeMode != "none":
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
else:
batch_timestamp = -1
watermark_timestamp = -1
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

result = handle_data_with_timers(
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows
Expand Down Expand Up @@ -614,12 +620,7 @@ def transformWithStateWithInitStateUDF(
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])

if timeMode != "none":
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
else:
batch_timestamp = -1
watermark_timestamp = -1
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

# only process initial state if first batch and initial state is not None
if initialStates is not None:
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/sql/streaming/stateful_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,13 @@ def handleInputRows(
Timer value for the current batch that processes the input rows.
Users can get the processing or event time timestamp from TimerValues.
"""
return iter([])
...

def handleExpiredTimer(
self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]:
"""
Optional to implement. Will return an empty iterator if not defined.
Function that will be invoked when a timer is fired for a given key. Users can choose to
evict state, register new timers and optionally provide output rows.
Expand Down

0 comments on commit 0c5ab3f

Please sign in to comment.