[SPARK-49899][PYTHON][SS] Support deleteIfExists for TransformWithStateInPandas

### What changes were proposed in this pull request?

- Support `deleteIfExists` for TransformWithStateInPandas.
- Add `close()` support for `StatefulProcessor` (both are sketched in the example below).
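
To make the new surface concrete, here is a minimal user-side sketch of both additions. It is illustrative only, not code from this PR; the processor name, state names, and output schema are made up:

```python
import pandas as pd
from typing import Iterator

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)
from pyspark.sql.types import IntegerType, StructField, StructType


class CountProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.count_state = handle.getValueState("countState", state_schema)
        # New in this PR: drop a state variable left over from an earlier
        # version of this query, if it exists.
        handle.deleteIfExists("staleState")

    def handleInputRows(
        self, key, rows, timer_values, expired_timer_info
    ) -> Iterator[pd.DataFrame]:
        count = sum(len(pdf) for pdf in rows)
        yield pd.DataFrame({"id": [key[0]], "count": [count]})

    def close(self) -> None:
        # New in this PR: runs once after all input has been processed,
        # giving the processor a place to release resources.
        pass
```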

### Why are the changes needed?

Adds parity to TransformWithStateInPandas for functionality already supported in TransformWithState.

### Does this PR introduce _any_ user-facing change?

Yes. Users gain a `StatefulProcessorHandle.deleteIfExists` API and a `close()` lifecycle hook on `StatefulProcessor`.

### How was this patch tested?

New unit test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#48373 from bogao007/delete-if-exists.

Authored-by: bogao007 <bo.gao@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
bogao007 authored and HeartSaVioR committed Nov 8, 2024
1 parent 60acd2f commit e4638c8
Showing 10 changed files with 163 additions and 75 deletions.
16 changes: 16 additions & 0 deletions python/pyspark/sql/pandas/group_ops.py
@@ -565,6 +565,14 @@ def transformWithStateUDF(
                StatefulProcessorHandleState.INITIALIZED
            )

        # Key is None when we have processed all the input data from the worker and
        # are ready to proceed with the cleanup steps.
        if key is None:
            statefulProcessorApiClient.remove_implicit_key()
            statefulProcessor.close()
            statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])

        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
        return result

@@ -594,6 +602,14 @@ def transformWithStateWithInitStateUDF(
                StatefulProcessorHandleState.INITIALIZED
            )

        # Key is None when we have processed all the input data from the worker and
        # are ready to proceed with the cleanup steps.
        if key is None:
            statefulProcessorApiClient.remove_implicit_key()
            statefulProcessor.close()
            statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
            return iter([])

        # Only process the initial state if this is the first batch and the initial
        # state is not None.
        if initialStates is not None:
            for cur_initial_state in initialStates:
144 changes: 72 additions & 72 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -98,24 +98,34 @@ class StateResponseWithLongTypeVal(_message.Message):
    ) -> None: ...

class StatefulProcessorCall(_message.Message):
    __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState", "timerStateCall")
    __slots__ = (
        "setHandleState",
        "getValueState",
        "getListState",
        "getMapState",
        "timerStateCall",
        "deleteIfExists",
    )
    SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int]
    GETVALUESTATE_FIELD_NUMBER: _ClassVar[int]
    GETLISTSTATE_FIELD_NUMBER: _ClassVar[int]
    GETMAPSTATE_FIELD_NUMBER: _ClassVar[int]
    TIMERSTATECALL_FIELD_NUMBER: _ClassVar[int]
    DELETEIFEXISTS_FIELD_NUMBER: _ClassVar[int]
    setHandleState: SetHandleState
    getValueState: StateCallCommand
    getListState: StateCallCommand
    getMapState: StateCallCommand
    timerStateCall: TimerStateCallCommand
    deleteIfExists: StateCallCommand
    def __init__(
        self,
        setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ...,
        getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
        getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
        getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
        timerStateCall: _Optional[_Union[TimerStateCallCommand, _Mapping]] = ...,
        deleteIfExists: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
    ) -> None: ...

class StateVariableRequest(_message.Message):
6 changes: 6 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -363,6 +363,12 @@ def listTimers(self) -> Iterator[int]:
"""
return ListTimerIterator(self.stateful_processor_api_client)

def deleteIfExists(self, state_name: str) -> None:
"""
Function to delete and purge state variable if defined previously
"""
self.stateful_processor_api_client.delete_if_exists(state_name)


class StatefulProcessor(ABC):
"""
15 changes: 15 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -338,6 +338,21 @@ def get_map_state(
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error initializing map state: {response_message[1]}")

    def delete_if_exists(self, state_name: str) -> None:
        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

        state_call_command = stateMessage.StateCallCommand()
        state_call_command.stateName = state_name
        call = stateMessage.StatefulProcessorCall(deleteIfExists=state_call_command)
        message = stateMessage.StateRequest(statefulProcessorCall=call)

        self._send_proto_message(message.SerializeToString())
        response_message = self._receive_proto_message()
        status = response_message[0]
        if status != 0:
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error deleting state: {response_message[1]}")

    def _send_proto_message(self, message: bytes) -> None:
        # Writing zero here to indicate the message version. This allows us to evolve
        # the message format or even change the message protocol in the future.
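
For illustration, the request that `delete_if_exists` builds can be exercised standalone with the generated bindings; a minimal round-trip sketch (the state name is arbitrary, and the version-byte framing done by `_send_proto_message` is omitted):

```python
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

# Build the same request shape that delete_if_exists sends to the JVM.
state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = "tempState"
call = stateMessage.StatefulProcessorCall(deleteIfExists=state_call_command)
request = stateMessage.StateRequest(statefulProcessorCall=call)

# Serialize as it would travel over the socket, then parse it back the way
# the Scala state server does before dispatching on the method case.
payload = request.SerializeToString()
parsed = stateMessage.StateRequest.FromString(payload)
assert parsed.statefulProcessorCall.deleteIfExists.stateName == "tempState"
```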
@@ -25,6 +25,7 @@
from typing import cast

from pyspark import SparkConf
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import split
from pyspark.sql.types import (
StringType,
@@ -835,6 +836,21 @@ def close(self) -> None:
        pass


class SimpleStatefulProcessor(StatefulProcessor):
class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
    dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
    batch_id = 0

    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.num_violations_state = handle.getValueState("numViolations", state_schema)
        self.temp_state = handle.getValueState("tempState", state_schema)
        handle.deleteIfExists("tempState")

    def handleInputRows(
        self, key, rows, timer_values, expired_timer_info
    ) -> Iterator[pd.DataFrame]:
        with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
            self.temp_state.exists()
        new_violations = 0
        count = 0
        key_str = key[0]
@@ -873,10 +878,12 @@ def close(self) -> None:

# A stateful processor that inherits all behavior of SimpleStatefulProcessor, except
# that it uses TTL state with a large timeout.
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor):
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
        self.temp_state = handle.getValueState("tempState", state_schema)
        handle.deleteIfExists("tempState")


class TTLStatefulProcessor(StatefulProcessor):
11 changes: 11 additions & 0 deletions python/pyspark/worker.py
@@ -1958,6 +1958,17 @@ def process():
        try:
            serializer.dump_stream(out_iter, outfile)
        finally:
            # Send a signal to the TransformWithState UDF so it performs its cleanup steps.
            if (
                eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
            ):
                # Send the key as None to indicate that process() has finished.
                end_iter = func(split_index, iter([(None, None)]))
                # The iterator must be materialized to trigger the cleanup steps;
                # nothing else needs to be done here.
                for _ in end_iter:
                    pass
            if hasattr(out_iter, "close"):
                out_iter.close()

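
The shutdown contract above is easiest to see in isolation. Below is a hypothetical, self-contained sketch of the sentinel-driven cleanup pattern; the names are illustrative, not Spark APIs:

```python
from typing import Any, Iterator, Tuple


def make_stateful_udf(close_fn):
    # A generator function, so nothing (including cleanup) runs until the
    # result is iterated -- which is why the worker above drains end_iter.
    def udf(key: Any, rows: Any) -> Iterator[Tuple[Any, int]]:
        if key is None:
            # Sentinel: all real input has been processed; run teardown once.
            close_fn()
            return
        yield (key, sum(rows))

    return udf


udf = make_stateful_udf(close_fn=lambda: print("processor closed"))
print(list(udf("k", [1, 2, 3])))  # normal data path: [('k', 6)]
for _ in udf(None, None):  # cleanup path; draining triggers close_fn
    pass
```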
@@ -48,6 +48,7 @@ message StatefulProcessorCall {
    StateCallCommand getListState = 3;
    StateCallCommand getMapState = 4;
    TimerStateCallCommand timerStateCall = 5;
    StateCallCommand deleteIfExists = 6;
  }
}

@@ -343,6 +343,17 @@ class TransformWithStateInPandasStateServer(
      case _ =>
        throw new IllegalArgumentException("Invalid timer state method call")
    }
  case StatefulProcessorCall.MethodCase.DELETEIFEXISTS =>
    val stateName = message.getDeleteIfExists.getStateName
    statefulProcessorHandle.deleteIfExists(stateName)
    if (valueStates.contains(stateName)) {
      valueStates.remove(stateName)
    } else if (listStates.contains(stateName)) {
      listStates.remove(stateName)
    } else if (mapStates.contains(stateName)) {
      mapStates.remove(stateName)
    }
    sendResponse(0)
  case _ =>
    throw new IllegalArgumentException("Invalid method call")
}
@@ -186,6 +186,17 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
    }
  }

  test("delete if exists") {
    val stateCallCommandBuilder = StateCallCommand.newBuilder()
      .setStateName("stateName")
    val message = StatefulProcessorCall
      .newBuilder()
      .setDeleteIfExists(stateCallCommandBuilder.build())
      .build()
    stateServer.handleStatefulProcessorCall(message)
    verify(statefulProcessorHandle).deleteIfExists(any[String])
  }

  test("value state exists") {
    val message = ValueStateCall.newBuilder().setStateName(stateName)
      .setExists(Exists.newBuilder().build()).build()
