From e4638c888cf367f467ccffa245eb58f55a923e80 Mon Sep 17 00:00:00 2001 From: bogao007 Date: Fri, 8 Nov 2024 17:57:36 +0900 Subject: [PATCH] [SPARK-49899][PYTHON][SS] Support deleteIfExists for TransformWithStateInPandas ### What changes were proposed in this pull request? - Support deleteIfExists for TransformWithStateInPandas. - Add `close()` support for StatefulProcessor. ### Why are the changes needed? To bring TransformWithStateInPandas to parity with the functionality already supported in TransformWithState. ### Does this PR introduce _any_ user-facing change? Yes. `StatefulProcessorHandle.deleteIfExists` is now available in TransformWithStateInPandas, and `StatefulProcessor.close()` is now called once all input data has been processed. ### How was this patch tested? New unit test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48373 from bogao007/delete-if-exists. Authored-by: bogao007 Signed-off-by: Jungtaek Lim --- python/pyspark/sql/pandas/group_ops.py | 16 ++ .../sql/streaming/proto/StateMessage_pb2.py | 144 +++++++++--------- .../sql/streaming/proto/StateMessage_pb2.pyi | 12 +- .../sql/streaming/stateful_processor.py | 6 + .../stateful_processor_api_client.py | 15 ++ .../test_pandas_transform_with_state.py | 11 +- python/pyspark/worker.py | 11 ++ .../execution/streaming/StateMessage.proto | 1 + ...ransformWithStateInPandasStateServer.scala | 11 ++ ...ormWithStateInPandasStateServerSuite.scala | 11 ++ 10 files changed, 163 insertions(+), 75 deletions(-) diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 856af2abfe680..56efe0676c08f 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -565,6 +565,14 @@ def transformWithStateUDF( StatefulProcessorHandleState.INITIALIZED ) + # Key is None when we have processed all the input data from the worker and ready to + # proceed with the cleanup steps. 
+ if key is None: + statefulProcessorApiClient.remove_implicit_key() + statefulProcessor.close() + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) + return iter([]) + result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows) return result @@ -594,6 +602,14 @@ def transformWithStateWithInitStateUDF( StatefulProcessorHandleState.INITIALIZED ) + # Key is None when we have processed all the input data from the worker and ready to + # proceed with the cleanup steps. + if key is None: + statefulProcessorApiClient.remove_implicit_key() + statefulProcessor.close() + statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED) + return iter([]) + # only process initial state if first batch and initial state is not None if initialStates is not None: for cur_initial_state in initialStates: diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index aeb195ca10ba7..46bed10c45588 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xbf\x03\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x12T\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 
\x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"W\n\x1cStateResponseWithLongTypeVal\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03"\xea\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12_\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00\x42\x08\n\x06method"\xa8\x02\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x12T\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"\xda\x01\n\x0cTimerRequest\x12^\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00\x12`\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00\x42\x08\n\x06method"\xd4\x01\n\x11TimerValueRequest\x12_\n\x12getProcessingTimer\x18\x01 
\x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00\x12T\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00\x42\x08\n\x06method"/\n\x12\x45xpiryTimerRequest\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x9a\x01\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x1b\n\x13mapStateValueSchema\x18\x03 \x01(\t\x12\x46\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\x8f\x02\n\x15TimerStateCallCommand\x12Q\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00\x12M\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00\x12J\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00\x42\x08\n\x06method"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 
\x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\xe1\x05\n\x0cMapStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12L\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00\x12R\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00\x12R\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00\x12L\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00\x12\x44\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00\x12H\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00\x12N\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00\x12\x46\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"*\n\rRegisterTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03"(\n\x0b\x44\x65leteTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03" \n\nListTimers\x12\x12\n\niteratorId\x18\x01 \x01(\t"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\x1b\n\x08GetValue\x12\x0f\n\x07userKey\x18\x01 
\x01(\x0c"\x1e\n\x0b\x43ontainsKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"-\n\x0bUpdateValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c"\x1e\n\x08Iterator\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1a\n\x04Keys\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\x06Values\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\tRemoveKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' # noqa: E501 + b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xbf\x03\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x12T\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"W\n\x1cStateResponseWithLongTypeVal\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03"\xc6\x04\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 
\x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12_\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00\x12Z\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xa8\x02\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x12T\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"\xda\x01\n\x0cTimerRequest\x12^\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00\x12`\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00\x42\x08\n\x06method"\xd4\x01\n\x11TimerValueRequest\x12_\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00\x12T\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00\x42\x08\n\x06method"/\n\x12\x45xpiryTimerRequest\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 
\x01(\x03"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x9a\x01\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x1b\n\x13mapStateValueSchema\x18\x03 \x01(\t\x12\x46\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\x8f\x02\n\x15TimerStateCallCommand\x12Q\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00\x12M\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00\x12J\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00\x42\x08\n\x06method"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\xe1\x05\n\x0cMapStateCall\x12\x11\n\tstateName\x18\x01 
\x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12L\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00\x12R\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00\x12R\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00\x12L\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00\x12\x44\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00\x12H\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00\x12N\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00\x12\x46\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"*\n\rRegisterTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03"(\n\x0b\x44\x65leteTimer\x12\x19\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03" \n\nListTimers\x12\x12\n\niteratorId\x18\x01 \x01(\t"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\x1b\n\x08GetValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"\x1e\n\x0b\x43ontainsKey\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c"-\n\x0bUpdateValue\x12\x0f\n\x07userKey\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c"\x1e\n\x08Iterator\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1a\n\x04Keys\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\x06Values\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x1c\n\tRemoveKey\x12\x0f\n\x07userKey\x18\x01 
\x01(\x0c"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' # noqa: E501 ) _globals = globals() @@ -43,8 +43,8 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 4966 - _globals["_HANDLESTATE"]._serialized_end = 5062 + _globals["_HANDLESTATE"]._serialized_start = 5058 + _globals["_HANDLESTATE"]._serialized_end = 5154 _globals["_STATEREQUEST"]._serialized_start = 71 _globals["_STATEREQUEST"]._serialized_end = 518 _globals["_STATERESPONSE"]._serialized_start = 520 @@ -52,73 +52,73 @@ _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 594 _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 681 _globals["_STATEFULPROCESSORCALL"]._serialized_start = 684 - _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1174 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1177 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1473 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1476 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1700 - _globals["_TIMERREQUEST"]._serialized_start = 1703 - _globals["_TIMERREQUEST"]._serialized_end = 1921 - _globals["_TIMERVALUEREQUEST"]._serialized_start = 1924 - _globals["_TIMERVALUEREQUEST"]._serialized_end = 2136 - _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2138 - _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2185 - _globals["_GETPROCESSINGTIME"]._serialized_start = 2187 - _globals["_GETPROCESSINGTIME"]._serialized_end = 2206 - _globals["_GETWATERMARK"]._serialized_start = 2208 - 
_globals["_GETWATERMARK"]._serialized_end = 2222 - _globals["_STATECALLCOMMAND"]._serialized_start = 2225 - _globals["_STATECALLCOMMAND"]._serialized_end = 2379 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2382 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 2653 - _globals["_VALUESTATECALL"]._serialized_start = 2656 - _globals["_VALUESTATECALL"]._serialized_end = 3009 - _globals["_LISTSTATECALL"]._serialized_start = 3012 - _globals["_LISTSTATECALL"]._serialized_end = 3540 - _globals["_MAPSTATECALL"]._serialized_start = 3543 - _globals["_MAPSTATECALL"]._serialized_end = 4280 - _globals["_SETIMPLICITKEY"]._serialized_start = 4282 - _globals["_SETIMPLICITKEY"]._serialized_end = 4311 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 4313 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 4332 - _globals["_EXISTS"]._serialized_start = 4334 - _globals["_EXISTS"]._serialized_end = 4342 - _globals["_GET"]._serialized_start = 4344 - _globals["_GET"]._serialized_end = 4349 - _globals["_REGISTERTIMER"]._serialized_start = 4351 - _globals["_REGISTERTIMER"]._serialized_end = 4393 - _globals["_DELETETIMER"]._serialized_start = 4395 - _globals["_DELETETIMER"]._serialized_end = 4435 - _globals["_LISTTIMERS"]._serialized_start = 4437 - _globals["_LISTTIMERS"]._serialized_end = 4469 - _globals["_VALUESTATEUPDATE"]._serialized_start = 4471 - _globals["_VALUESTATEUPDATE"]._serialized_end = 4504 - _globals["_CLEAR"]._serialized_start = 4506 - _globals["_CLEAR"]._serialized_end = 4513 - _globals["_LISTSTATEGET"]._serialized_start = 4515 - _globals["_LISTSTATEGET"]._serialized_end = 4549 - _globals["_LISTSTATEPUT"]._serialized_start = 4551 - _globals["_LISTSTATEPUT"]._serialized_end = 4565 - _globals["_APPENDVALUE"]._serialized_start = 4567 - _globals["_APPENDVALUE"]._serialized_end = 4595 - _globals["_APPENDLIST"]._serialized_start = 4597 - _globals["_APPENDLIST"]._serialized_end = 4609 - _globals["_GETVALUE"]._serialized_start = 4611 - 
_globals["_GETVALUE"]._serialized_end = 4638 - _globals["_CONTAINSKEY"]._serialized_start = 4640 - _globals["_CONTAINSKEY"]._serialized_end = 4670 - _globals["_UPDATEVALUE"]._serialized_start = 4672 - _globals["_UPDATEVALUE"]._serialized_end = 4717 - _globals["_ITERATOR"]._serialized_start = 4719 - _globals["_ITERATOR"]._serialized_end = 4749 - _globals["_KEYS"]._serialized_start = 4751 - _globals["_KEYS"]._serialized_end = 4777 - _globals["_VALUES"]._serialized_start = 4779 - _globals["_VALUES"]._serialized_end = 4807 - _globals["_REMOVEKEY"]._serialized_start = 4809 - _globals["_REMOVEKEY"]._serialized_end = 4837 - _globals["_SETHANDLESTATE"]._serialized_start = 4839 - _globals["_SETHANDLESTATE"]._serialized_end = 4931 - _globals["_TTLCONFIG"]._serialized_start = 4933 - _globals["_TTLCONFIG"]._serialized_end = 4964 + _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1266 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1269 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1565 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1568 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1792 + _globals["_TIMERREQUEST"]._serialized_start = 1795 + _globals["_TIMERREQUEST"]._serialized_end = 2013 + _globals["_TIMERVALUEREQUEST"]._serialized_start = 2016 + _globals["_TIMERVALUEREQUEST"]._serialized_end = 2228 + _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2230 + _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2277 + _globals["_GETPROCESSINGTIME"]._serialized_start = 2279 + _globals["_GETPROCESSINGTIME"]._serialized_end = 2298 + _globals["_GETWATERMARK"]._serialized_start = 2300 + _globals["_GETWATERMARK"]._serialized_end = 2314 + _globals["_STATECALLCOMMAND"]._serialized_start = 2317 + _globals["_STATECALLCOMMAND"]._serialized_end = 2471 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2474 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 2745 + _globals["_VALUESTATECALL"]._serialized_start = 2748 + 
_globals["_VALUESTATECALL"]._serialized_end = 3101 + _globals["_LISTSTATECALL"]._serialized_start = 3104 + _globals["_LISTSTATECALL"]._serialized_end = 3632 + _globals["_MAPSTATECALL"]._serialized_start = 3635 + _globals["_MAPSTATECALL"]._serialized_end = 4372 + _globals["_SETIMPLICITKEY"]._serialized_start = 4374 + _globals["_SETIMPLICITKEY"]._serialized_end = 4403 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 4405 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 4424 + _globals["_EXISTS"]._serialized_start = 4426 + _globals["_EXISTS"]._serialized_end = 4434 + _globals["_GET"]._serialized_start = 4436 + _globals["_GET"]._serialized_end = 4441 + _globals["_REGISTERTIMER"]._serialized_start = 4443 + _globals["_REGISTERTIMER"]._serialized_end = 4485 + _globals["_DELETETIMER"]._serialized_start = 4487 + _globals["_DELETETIMER"]._serialized_end = 4527 + _globals["_LISTTIMERS"]._serialized_start = 4529 + _globals["_LISTTIMERS"]._serialized_end = 4561 + _globals["_VALUESTATEUPDATE"]._serialized_start = 4563 + _globals["_VALUESTATEUPDATE"]._serialized_end = 4596 + _globals["_CLEAR"]._serialized_start = 4598 + _globals["_CLEAR"]._serialized_end = 4605 + _globals["_LISTSTATEGET"]._serialized_start = 4607 + _globals["_LISTSTATEGET"]._serialized_end = 4641 + _globals["_LISTSTATEPUT"]._serialized_start = 4643 + _globals["_LISTSTATEPUT"]._serialized_end = 4657 + _globals["_APPENDVALUE"]._serialized_start = 4659 + _globals["_APPENDVALUE"]._serialized_end = 4687 + _globals["_APPENDLIST"]._serialized_start = 4689 + _globals["_APPENDLIST"]._serialized_end = 4701 + _globals["_GETVALUE"]._serialized_start = 4703 + _globals["_GETVALUE"]._serialized_end = 4730 + _globals["_CONTAINSKEY"]._serialized_start = 4732 + _globals["_CONTAINSKEY"]._serialized_end = 4762 + _globals["_UPDATEVALUE"]._serialized_start = 4764 + _globals["_UPDATEVALUE"]._serialized_end = 4809 + _globals["_ITERATOR"]._serialized_start = 4811 + _globals["_ITERATOR"]._serialized_end = 4841 + 
_globals["_KEYS"]._serialized_start = 4843 + _globals["_KEYS"]._serialized_end = 4869 + _globals["_VALUES"]._serialized_start = 4871 + _globals["_VALUES"]._serialized_end = 4899 + _globals["_REMOVEKEY"]._serialized_start = 4901 + _globals["_REMOVEKEY"]._serialized_end = 4929 + _globals["_SETHANDLESTATE"]._serialized_start = 4931 + _globals["_SETHANDLESTATE"]._serialized_end = 5023 + _globals["_TTLCONFIG"]._serialized_start = 5025 + _globals["_TTLCONFIG"]._serialized_end = 5056 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi index ff525ee136a45..bc5138f52281c 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi @@ -98,17 +98,26 @@ class StateResponseWithLongTypeVal(_message.Message): ) -> None: ... class StatefulProcessorCall(_message.Message): - __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState", "timerStateCall") + __slots__ = ( + "setHandleState", + "getValueState", + "getListState", + "getMapState", + "timerStateCall", + "deleteIfExists", + ) SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] TIMERSTATECALL_FIELD_NUMBER: _ClassVar[int] + DELETEIFEXISTS_FIELD_NUMBER: _ClassVar[int] setHandleState: SetHandleState getValueState: StateCallCommand getListState: StateCallCommand getMapState: StateCallCommand timerStateCall: TimerStateCallCommand + deleteIfExists: StateCallCommand def __init__( self, setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., @@ -116,6 +125,7 @@ class StatefulProcessorCall(_message.Message): getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., timerStateCall: _Optional[_Union[TimerStateCallCommand, _Mapping]] = 
..., + deleteIfExists: _Optional[_Union[StateCallCommand, _Mapping]] = ..., ) -> None: ... class StateVariableRequest(_message.Message): diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 266fb2d3e735e..20078c215bace 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -363,6 +363,12 @@ def listTimers(self) -> Iterator[int]: """ return ListTimerIterator(self.stateful_processor_api_client) + def deleteIfExists(self, state_name: str) -> None: + """ + Function to delete and purge state variable if defined previously + """ + self.stateful_processor_api_client.delete_if_exists(state_name) + class StatefulProcessor(ABC): """ diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index ce3bae0a7c91d..353f75e267962 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -338,6 +338,21 @@ def get_map_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing map state: " f"{response_message[1]}") + def delete_if_exists(self, state_name: str) -> None: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + state_call_command = stateMessage.StateCallCommand() + state_call_command.stateName = state_name + call = stateMessage.StatefulProcessorCall(deleteIfExists=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error deleting state: " f"{response_message[1]}") + def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. This allows us to evolve the message # format or even changing the message protocol in the future. diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index bb7c31119ef8a..46aad4b6bc60d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -25,6 +25,7 @@ from typing import cast from pyspark import SparkConf +from pyspark.errors import PySparkRuntimeError from pyspark.sql.functions import split from pyspark.sql.types import ( StringType, @@ -835,17 +836,21 @@ def close(self) -> None: pass -class SimpleStatefulProcessor(StatefulProcessor): +class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase): dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}} batch_id = 0 def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = handle.getValueState("numViolations", state_schema) + self.temp_state = handle.getValueState("tempState", state_schema) + handle.deleteIfExists("tempState") def handleInputRows( self, key, rows, timer_values, expired_timer_info ) -> Iterator[pd.DataFrame]: + with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"): + self.temp_state.exists() new_violations = 0 count = 0 key_str = key[0] @@ -873,10 +878,12 @@ def close(self) -> None: # A stateful processor that inherit all behavior of SimpleStatefulProcessor except that it use # ttl state with a large timeout. 
-class SimpleTTLStatefulProcessor(SimpleStatefulProcessor): +class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000) + self.temp_state = handle.getValueState("tempState", state_schema) + handle.deleteIfExists("tempState") class TTLStatefulProcessor(StatefulProcessor): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 10418f0487c94..04f95e9f52648 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1958,6 +1958,17 @@ def process(): try: serializer.dump_stream(out_iter, outfile) finally: + # Sending a signal to TransformWithState UDF to perform proper cleanup steps. + if ( + eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF + or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF + ): + # Sending key as None to indicate that process() has finished. + end_iter = func(split_index, iter([(None, None)])) + # Need to materialize the iterator to trigger the cleanup steps, nothing needs + # to be done here. 
+ for _ in end_iter: + pass if hasattr(out_iter, "close"): out_iter.close() diff --git a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto index 544cd3b10b1ca..4b0477290c8f7 100644 --- a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto +++ b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto @@ -48,6 +48,7 @@ message StatefulProcessorCall { StateCallCommand getListState = 3; StateCallCommand getMapState = 4; TimerStateCallCommand timerStateCall = 5; + StateCallCommand deleteIfExists = 6; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index 5f3ebd87e75e4..0373c8607ff2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -343,6 +343,17 @@ class TransformWithStateInPandasStateServer( case _ => throw new IllegalArgumentException("Invalid timer state method call") } + case StatefulProcessorCall.MethodCase.DELETEIFEXISTS => + val stateName = message.getDeleteIfExists.getStateName + statefulProcessorHandle.deleteIfExists(stateName) + if (valueStates.contains(stateName)) { + valueStates.remove(stateName) + } else if (listStates.contains(stateName)) { + listStates.remove(stateName) + } else if (mapStates.contains(stateName)) { + mapStates.remove(stateName) + } + sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 3925c3d62da37..e05264825f773 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -186,6 +186,17 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } } + test("delete if exists") { + val stateCallCommandBuilder = StateCallCommand.newBuilder() + .setStateName("stateName") + val message = StatefulProcessorCall + .newBuilder() + .setDeleteIfExists(stateCallCommandBuilder.build()) + .build() + stateServer.handleStatefulProcessorCall(message) + verify(statefulProcessorHandle).deleteIfExists(any[String]) + } + test("value state exists") { val message = ValueStateCall.newBuilder().setStateName(stateName) .setExists(Exists.newBuilder().build()).build()