Made the logic of handling timers and cleanup simpler
HeartSaVioR committed Nov 27, 2024
1 parent 0c5ab3f commit f8952b2
Showing 5 changed files with 130 additions and 91 deletions.
112 changes: 58 additions & 54 deletions python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@
TimerValues,
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
@@ -513,25 +514,28 @@ def get_timestamps(
watermark_timestamp = -1
return batch_timestamp, watermark_timestamp

def handle_data_with_timers(
def handle_data_rows(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
batch_timestamp: int,
watermark_timestamp: int,
inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
) -> Iterator["PandasDataFrameLike"]:
statefulProcessorApiClient.set_implicit_key(key)
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)
# process with data rows
if inputRows is not None:
data_iter = statefulProcessor.handleInputRows(
key, inputRows, TimerValues(batch_timestamp, watermark_timestamp)
)
result_iter_list = [data_iter]
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.DATA_PROCESSED
)
return data_iter
else:
result_iter_list = []
return iter([])

def handle_expired_timers(
statefulProcessorApiClient: StatefulProcessorApiClient,
) -> Iterator["PandasDataFrameLike"]:
result_iter_list = []

batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

if timeMode.lower() == "processingtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
@@ -547,19 +551,20 @@ def handle_data_with_timers(
# process with expiry timers, only timer related rows will be emitted
for expiry_list in expiry_list_iter:
for key_obj, expiry_timestamp in expiry_list:
statefulProcessorApiClient.set_implicit_key(key_obj)
result_iter_list.append(
statefulProcessor.handleExpiredTimer(
key_obj,
TimerValues(batch_timestamp, watermark_timestamp),
ExpiredTimerInfo(expiry_timestamp),
)
)
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
return result

return itertools.chain(*result_iter_list)

def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
@@ -571,23 +576,24 @@ def transformWithStateUDF(
StatefulProcessorHandleState.INITIALIZED
)

# Key is None when we have processed all the input data from the worker and ready to
# proceed with the cleanup steps.
if key is None:
if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
result = handle_expired_timers(statefulProcessorApiClient)
return result
elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.TIMER_PROCESSED)
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])

batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

result = handle_data_with_timers(
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows
)
return result
else:
# mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
return result

def transformWithStateWithInitStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
@@ -612,45 +618,43 @@ def transformWithStateWithInitStateUDF(
StatefulProcessorHandleState.INITIALIZED
)

# Key is None when we have processed all the input data from the worker and ready to
# proceed with the cleanup steps.
if key is None:
if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
result = handle_expired_timers(statefulProcessorApiClient)
return result
elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])

batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

# only process initial state if first batch and initial state is not None
if initialStates is not None:
for cur_initial_state in initialStates:
statefulProcessorApiClient.set_implicit_key(key)
statefulProcessor.handleInitialState(
key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
)

# if we don't have input rows for the given key but only have initial state
# for the grouping key, the inputRows iterator could be empty
input_rows_empty = False
try:
first = next(inputRows)
except StopIteration:
input_rows_empty = True
else:
inputRows = itertools.chain([first], inputRows)
# mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
batch_timestamp, watermark_timestamp = get_timestamps(statefulProcessorApiClient)

# only process initial state if first batch and initial state is not None
if initialStates is not None:
for cur_initial_state in initialStates:
statefulProcessorApiClient.set_implicit_key(key)
statefulProcessor.handleInitialState(
key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
)

if not input_rows_empty:
result = handle_data_with_timers(
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, inputRows
)
else:
# if the input rows is empty, we still need to handle the expired timers registered
# in the initial state
result = handle_data_with_timers(
statefulProcessorApiClient, key, batch_timestamp, watermark_timestamp, None
)
return result
# if we don't have input rows for the given key but only have initial state
# for the grouping key, the inputRows iterator could be empty
input_rows_empty = False
try:
first = next(inputRows)
except StopIteration:
input_rows_empty = True
else:
inputRows = itertools.chain([first], inputRows)

if not input_rows_empty:
result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
else:
result = iter([])

return result

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
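The restructured UDFs above replace the old "key is None" sentinel with an explicit mode argument: PROCESS_DATA handles grouped input rows, PROCESS_TIMER fires expired timers once all data has been seen, and COMPLETE closes the processor. The following is a minimal sketch of that dispatch, assuming a PySpark build that contains this change; client and processor are hypothetical stand-ins for StatefulProcessorApiClient and the user's StatefulProcessor (the real handleInputRows also receives TimerValues), and expired-timer handling is elided.

from typing import Any, Iterator

from pyspark.sql.streaming.stateful_processor_util import (
    TransformWithStateInPandasFuncMode,
)


def dispatch_udf(
    client: Any,  # stand-in for StatefulProcessorApiClient
    processor: Any,  # stand-in for the user's StatefulProcessor
    mode: TransformWithStateInPandasFuncMode,
    key: Any,
    rows: Iterator[Any],
) -> Iterator[Any]:
    if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
        # All data rows have been processed; only timer-related rows would be
        # emitted here (expired-timer handling elided in this sketch).
        return iter([])
    elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
        # Final call for the task: clean up and close the processor.
        client.remove_implicit_key()
        processor.close()
        return iter([])
    else:
        # TransformWithStateInPandasFuncMode.PROCESS_DATA
        client.set_implicit_key(key)
        return processor.handleInputRows(key, rows)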
14 changes: 11 additions & 3 deletions python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@
_create_converter_from_pandas,
_create_converter_to_pandas,
)
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import (
DataType,
StringType,
@@ -1140,7 +1141,6 @@ def init_stream_yield_batches(batches):

return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)


class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
"""
Serializer used by Python worker to evaluate UDF for
@@ -1197,7 +1197,11 @@ def generate_data_batches(batches):
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)

def dump_stream(self, iterator, stream):
"""
@@ -1281,4 +1285,8 @@ def flatten_columns(cur_batch, col_name):
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)

yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)

yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
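With this change, load_stream no longer ends at the last key group: after yielding (PROCESS_DATA, key, batches) for each group it appends a PROCESS_TIMER and a COMPLETE sentinel, which is what lets the worker drop its old end-of-input (None, None) signal. Below is a small, self-contained sketch of that producer protocol using plain key/value pairs instead of Arrow batches; it assumes a PySpark build that ships the new enum.

from itertools import groupby

from pyspark.sql.streaming.stateful_processor_util import (
    TransformWithStateInPandasFuncMode,
)


def load_stream_sketch(records):
    """Yield (mode, key, group) tuples for records already sorted by key."""
    for key, group in groupby(records, key=lambda kv: kv[0]):
        yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, key, group)
    # Emitted exactly once, after every key group has been consumed.
    yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
    yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)


# Example: two data phases followed by the timer and cleanup sentinels.
for mode, key, group in load_stream_sketch([("a", 1), ("a", 2), ("b", 3)]):
    print(mode.name, key, None if group is None else [v for _, v in group])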
26 changes: 26 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor_util.py
@@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from enum import Enum

# This file places the utilities for transformWithStateInPandas; we have a separate file to avoid
# putting internal classes to the stateful_processor.py file which contains public APIs.

class TransformWithStateInPandasFuncMode(Enum):
PROCESS_DATA = 1
PROCESS_TIMER = 2
COMPLETE = 3
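The three members encode the order of the phases a transformWithStateInPandas task goes through: process grouped data, then fire expired timers, then run cleanup. A trivial usage check, assuming a PySpark build that includes this new file:

from pyspark.sql.streaming.stateful_processor_util import (
    TransformWithStateInPandasFuncMode,
)

# Phase order: PROCESS_DATA (1) -> PROCESS_TIMER (2) -> COMPLETE (3).
for mode in TransformWithStateInPandasFuncMode:
    print(mode.name, mode.value)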
@@ -61,6 +61,7 @@ def conf(cls):
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)
cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2")
cfg.set("spark.sql.session.timeZone", "UTC")
return cfg

def _prepare_input_data(self, input_path, col1, col2):
68 changes: 34 additions & 34 deletions python/pyspark/worker.py
@@ -34,6 +34,7 @@
_deserialize_accumulator,
)
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.resource import ResourceInformation
from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -493,36 +494,36 @@ def wrapped(key_series, value_series):


def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, key, value_series_gen):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
import pandas as pd

values = (pd.concat(x, axis=1) for x in value_series_gen)
result_iter = f(stateful_processor_api_client, key, values)
result_iter = f(stateful_processor_api_client, mode, key, values)

# TODO(SPARK-49100): add verification that elements in result_iter are
# indeed of type pd.DataFrame and confirm to assigned cols

return result_iter

return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))]


def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, key, value_series_gen):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
import pandas as pd

state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty)
init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty)

result_iter = f(stateful_processor_api_client, key, state_values, init_states)
result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)

# TODO(SPARK-49100): add verification that elements in result_iter are
# indeed of type pd.DataFrame and confirm to assigned cols

return result_iter

return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))]


def wrap_grouped_map_pandas_udf_with_state(f, return_type):
@@ -1697,18 +1698,22 @@ def mapper(a):
ser.key_offsets = parsed_offsets[0][0]
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

# Create function like this:
# mapper a: f([a[0]], [a[0], a[1]])
def mapper(a):
key = a[0]
mode = a[0]

def values_gen():
for x in a[1]:
retVal = [x[1][o] for o in parsed_offsets[0][1]]
yield retVal
if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
key = a[1]

# This must be generator comprehension - do not materialize.
return f(stateful_processor_api_client, key, values_gen())
def values_gen():
for x in a[2]:
retVal = [x[1][o] for o in parsed_offsets[0][1]]
yield retVal

# This must be generator comprehension - do not materialize.
return f(stateful_processor_api_client, mode, key, values_gen())
else:
# mode == PROCESS_TIMER or mode == COMPLETE
return f(stateful_processor_api_client, mode, None, iter([]))

elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
# We assume there is only one UDF here because grouped map doesn't
@@ -1731,16 +1736,22 @@ def values_gen():
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

def mapper(a):
key = a[0]
mode = a[0]

def values_gen():
for x in a[1]:
retVal = [x[1][o] for o in parsed_offsets[0][1]]
initVal = [x[2][o] for o in parsed_offsets[1][1]]
yield retVal, initVal
if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
key = a[1]

# This must be generator comprehension - do not materialize.
return f(stateful_processor_api_client, key, values_gen())
def values_gen():
for x in a[2]:
retVal = [x[1][o] for o in parsed_offsets[0][1]]
initVal = [x[2][o] for o in parsed_offsets[1][1]]
yield retVal, initVal

# This must be generator comprehension - do not materialize.
return f(stateful_processor_api_client, mode, key, values_gen())
else:
# mode == PROCESS_TIMER or mode == COMPLETE
return f(stateful_processor_api_client, mode, None, iter([]))

elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
import pyarrow as pa
@@ -1958,17 +1969,6 @@ def process():
try:
serializer.dump_stream(out_iter, outfile)
finally:
# Sending a signal to TransformWithState UDF to perform proper cleanup steps.
if (
eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
):
# Sending key as None to indicate that process() has finished.
end_iter = func(split_index, iter([(None, None)]))
# Need to materialize the iterator to trigger the cleanup steps, nothing needs
# to be done here.
for _ in end_iter:
pass
if hasattr(out_iter, "close"):
out_iter.close()

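On the worker side the mapper now branches on the mode it receives from the serializer, so the old finally block that pushed a (None, None) end-of-input signal through the UDF is no longer needed. The following is a hedged sketch of that mapper shape; udf and client are hypothetical stand-ins for the wrapped transformWithState UDF and the StatefulProcessorApiClient, and it assumes a PySpark build with the new enum.

from pyspark.sql.streaming.stateful_processor_util import (
    TransformWithStateInPandasFuncMode,
)


def make_mapper(udf, client):
    def mapper(item):
        mode = item[0]
        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
            key, batch_iter = item[1], item[2]
            # Stay lazy: batches are passed through as a generator and never
            # materialized in the mapper.
            return udf(client, mode, key, batch_iter)
        else:
            # PROCESS_TIMER and COMPLETE carry no key and no input rows.
            return udf(client, mode, None, iter([]))

    return mapper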
