diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py
index d2152842819a..cb618d1a691b 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, List, Union, cast, Tuple
+from typing import Dict, Iterator, List, Union, Tuple
 
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.sql.types import StructType, TYPE_CHECKING
 from pyspark.errors import PySparkRuntimeError
 
 import uuid
@@ -28,8 +28,16 @@
 
 
 class ListStateClient:
-    def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        schema: Union[StructType, str],
+    ) -> None:
         self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(schema, str):
+            self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
+        else:
+            self.schema = schema
         # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
         # and the index of the last row that was read.
         self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
@@ -105,12 +113,10 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
             pandas_row = pandas_df.iloc[index]
             return tuple(pandas_row)
 
-    def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None:
+    def append_value(self, state_name: str, value: Tuple) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
-        bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value)
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.schema, value)
         append_value_call = stateMessage.AppendValue(value=bytes)
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendValue=append_value_call
         )
@@ -125,13 +131,9 @@ def append_value(self, state_name: str, schema: Union[StructType, str], value: T
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")
 
-    def append_list(
-        self, state_name: str, schema: Union[StructType, str], values: List[Tuple]
-    ) -> None:
+    def append_list(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
         append_list_call = stateMessage.AppendList()
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendList=append_list_call
         )
@@ -141,18 +143,16 @@ def append_list(
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")
 
-    def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None:
+    def put(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
         put_call = stateMessage.ListStatePut()
         list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
         state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
@@ -160,7 +160,7 @@ def put(self, state_name: str, schema: Union[StructType, str], values: List[Tupl
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
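
Note: the same refactor repeats in every state client below. Schema resolution moves out of the per-call methods (append_value, append_list, put) and into __init__, so a DDL string costs one parse per state variable instead of one per operation. A minimal sketch of the pattern; the api_client argument here is a hypothetical stand-in for StatefulProcessorApiClient:

    from typing import List, Tuple, Union

    from pyspark.sql.types import StructType

    class SchemaOnceClient:
        """Hypothetical client illustrating the parse-once pattern above."""

        def __init__(self, api_client, schema: Union[StructType, str]) -> None:
            self._api_client = api_client
            # Resolve the schema a single time; string schemas go through the JVM parser.
            if isinstance(schema, str):
                self.schema = api_client._parse_string_schema(schema)
            else:
                self.schema = schema

        def put(self, state_name: str, values: List[Tuple]) -> None:
            # Hot path: no parsing, just reuse the cached StructType.
            self._api_client._send_arrow_state(self.schema, values)
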
raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}") - def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None: + def put(self, state_name: str, values: List[Tuple]) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - if isinstance(schema, str): - schema = cast(StructType, _parse_datatype_string(schema)) put_call = stateMessage.ListStatePut() list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call) state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) @@ -160,7 +160,7 @@ def put(self, state_name: str, schema: Union[StructType, str], values: List[Tupl self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - self._stateful_processor_api_client._send_arrow_state(schema, values) + self._stateful_processor_api_client._send_arrow_state(self.schema, values) response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status != 0: diff --git a/python/pyspark/sql/streaming/map_state_client.py b/python/pyspark/sql/streaming/map_state_client.py index 6ec7448b4863..c4761ddd48a1 100644 --- a/python/pyspark/sql/streaming/map_state_client.py +++ b/python/pyspark/sql/streaming/map_state_client.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Dict, Iterator, Union, cast, Tuple, Optional +from typing import Dict, Iterator, Union, Tuple, Optional from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient -from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string +from pyspark.sql.types import StructType, TYPE_CHECKING from pyspark.errors import PySparkRuntimeError import uuid @@ -36,11 +36,15 @@ def __init__( ) -> None: self._stateful_processor_api_client = stateful_processor_api_client if isinstance(user_key_schema, str): - self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema)) + self.user_key_schema = self._stateful_processor_api_client._parse_string_schema( + user_key_schema + ) else: self.user_key_schema = user_key_schema if isinstance(value_schema, str): - self.value_schema = cast(StructType, _parse_datatype_string(value_schema)) + self.value_schema = self._stateful_processor_api_client._parse_string_schema( + value_schema + ) else: self.value_schema = value_schema # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 0a54690513a3..bcd8e0fc68f5 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xa0\x04\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14stateVariableRequest\x12\x8c\x01\n\x1aimplicitGroupingKeyRequest\x18\x04 
\x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00R\x1aimplicitGroupingKeyRequest\x12\x62\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00R\x0ctimerRequestB\x08\n\x06method"i\n\rStateResponse\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x0cR\x05value"x\n\x1cStateResponseWithLongTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x03R\x05value"\xa0\x05\n\x15StatefulProcessorCall\x12h\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00R\x0esetHandleState\x12h\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\rgetValueState\x12\x66\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0cgetListState\x12\x64\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0bgetMapState\x12o\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00R\x0etimerStateCall\x12j\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0e\x64\x65leteIfExistsB\x08\n\x06method"\xd5\x02\n\x14StateVariableRequest\x12h\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00R\x0evalueStateCall\x12\x65\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00R\rlistStateCall\x12\x62\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00R\x0cmapStateCallB\x08\n\x06method"\x83\x02\n\x1aImplicitGroupingKeyRequest\x12h\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00R\x0esetImplicitKey\x12q\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00R\x11removeImplicitKeyB\x08\n\x06method"\x81\x02\n\x0cTimerRequest\x12q\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00R\x11timerValueRequest\x12t\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00R\x12\x65xpiryTimerRequestB\x08\n\x06method"\xf6\x01\n\x11TimerValueRequest\x12s\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00R\x12getProcessingTimer\x12\x62\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00R\x0cgetWatermarkB\x08\n\x06method"B\n\x12\x45xpiryTimerRequest\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\xc7\x01\n\x10StateCallCommand\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x30\n\x13mapStateValueSchema\x18\x03 \x01(\tR\x13mapStateValueSchema\x12K\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfigR\x03ttl"\xa7\x02\n\x15TimerStateCallCommand\x12[\n\x08register\x18\x01 
\x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00R\x08register\x12U\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00R\x06\x64\x65lete\x12P\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00R\x04listB\x08\n\x06method"\x92\x03\n\x0eValueStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12G\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00R\x03get\x12n\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00R\x10valueStateUpdate\x12M\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xdf\x04\n\rListStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12\x62\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00R\x0clistStateGet\x12\x62\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00R\x0clistStatePut\x12_\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00R\x0b\x61ppendValue\x12\\\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00R\nappendList\x12M\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xc2\x06\n\x0cMapStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12V\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00R\x08getValue\x12_\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00R\x0b\x63ontainsKey\x12_\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00R\x0bupdateValue\x12V\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00R\x08iterator\x12J\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00R\x04keys\x12P\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00R\x06values\x12Y\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00R\tremoveKey\x12M\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method""\n\x0eSetImplicitKey\x12\x10\n\x03key\x18\x01 \x01(\x0cR\x03key"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"=\n\rRegisterTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs";\n\x0b\x44\x65leteTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs",\n\nListTimers\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x10ValueStateUpdate\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x07\n\x05\x43lear".\n\x0cListStateGet\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"\x0e\n\x0cListStatePut"#\n\x0b\x41ppendValue\x12\x14\n\x05value\x18\x01 
\x01(\x0cR\x05value"\x0c\n\nAppendList"$\n\x08GetValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"\'\n\x0b\x43ontainsKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"=\n\x0bUpdateValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05value"*\n\x08Iterator\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"&\n\x04Keys\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x06Values\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"%\n\tRemoveKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"c\n\x0eSetHandleState\x12Q\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleStateR\x05state"+\n\tTTLConfig\x12\x1e\n\ndurationMs\x18\x01 \x01(\x05R\ndurationMs*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' + b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14stateVariableRequest\x12\x8c\x01\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00R\x1aimplicitGroupingKeyRequest\x12\x62\n\x0ctimerRequest\x18\x05 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.TimerRequestH\x00R\x0ctimerRequest\x12\x62\n\x0cutilsRequest\x18\x06 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.UtilsRequestH\x00R\x0cutilsRequestB\x08\n\x06method"i\n\rStateResponse\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x0cR\x05value"x\n\x1cStateResponseWithLongTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\x03R\x05value"z\n\x1eStateResponseWithStringTypeVal\x12\x1e\n\nstatusCode\x18\x01 \x01(\x05R\nstatusCode\x12"\n\x0c\x65rrorMessage\x18\x02 \x01(\tR\x0c\x65rrorMessage\x12\x14\n\x05value\x18\x03 \x01(\tR\x05value"\xa0\x05\n\x15StatefulProcessorCall\x12h\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00R\x0esetHandleState\x12h\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\rgetValueState\x12\x66\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0cgetListState\x12\x64\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0bgetMapState\x12o\n\x0etimerStateCall\x18\x05 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.TimerStateCallCommandH\x00R\x0etimerStateCall\x12j\n\x0e\x64\x65leteIfExists\x18\x06 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00R\x0e\x64\x65leteIfExistsB\x08\n\x06method"\xd5\x02\n\x14StateVariableRequest\x12h\n\x0evalueStateCall\x18\x01 
\x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00R\x0evalueStateCall\x12\x65\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00R\rlistStateCall\x12\x62\n\x0cmapStateCall\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.MapStateCallH\x00R\x0cmapStateCallB\x08\n\x06method"\x83\x02\n\x1aImplicitGroupingKeyRequest\x12h\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00R\x0esetImplicitKey\x12q\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00R\x11removeImplicitKeyB\x08\n\x06method"\x81\x02\n\x0cTimerRequest\x12q\n\x11timerValueRequest\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.TimerValueRequestH\x00R\x11timerValueRequest\x12t\n\x12\x65xpiryTimerRequest\x18\x02 \x01(\x0b\x32\x42.org.apache.spark.sql.execution.streaming.state.ExpiryTimerRequestH\x00R\x12\x65xpiryTimerRequestB\x08\n\x06method"\xf6\x01\n\x11TimerValueRequest\x12s\n\x12getProcessingTimer\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.GetProcessingTimeH\x00R\x12getProcessingTimer\x12\x62\n\x0cgetWatermark\x18\x02 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.GetWatermarkH\x00R\x0cgetWatermarkB\x08\n\x06method"B\n\x12\x45xpiryTimerRequest\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs"\x13\n\x11GetProcessingTime"\x0e\n\x0cGetWatermark"\x8b\x01\n\x0cUtilsRequest\x12q\n\x11parseStringSchema\x18\x01 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.ParseStringSchemaH\x00R\x11parseStringSchemaB\x08\n\x06method"+\n\x11ParseStringSchema\x12\x16\n\x06schema\x18\x01 \x01(\tR\x06schema"\xc7\x01\n\x10StateCallCommand\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x30\n\x13mapStateValueSchema\x18\x03 \x01(\tR\x13mapStateValueSchema\x12K\n\x03ttl\x18\x04 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfigR\x03ttl"\xa7\x02\n\x15TimerStateCallCommand\x12[\n\x08register\x18\x01 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.RegisterTimerH\x00R\x08register\x12U\n\x06\x64\x65lete\x18\x02 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.DeleteTimerH\x00R\x06\x64\x65lete\x12P\n\x04list\x18\x03 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.ListTimersH\x00R\x04listB\x08\n\x06method"\x92\x03\n\x0eValueStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12G\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00R\x03get\x12n\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00R\x10valueStateUpdate\x12M\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xdf\x04\n\rListStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12\x62\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00R\x0clistStateGet\x12\x62\n\x0clistStatePut\x18\x04 
\x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00R\x0clistStatePut\x12_\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00R\x0b\x61ppendValue\x12\\\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00R\nappendList\x12M\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method"\xc2\x06\n\x0cMapStateCall\x12\x1c\n\tstateName\x18\x01 \x01(\tR\tstateName\x12P\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00R\x06\x65xists\x12V\n\x08getValue\x18\x03 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.GetValueH\x00R\x08getValue\x12_\n\x0b\x63ontainsKey\x18\x04 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.ContainsKeyH\x00R\x0b\x63ontainsKey\x12_\n\x0bupdateValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.UpdateValueH\x00R\x0bupdateValue\x12V\n\x08iterator\x18\x06 \x01(\x0b\x32\x38.org.apache.spark.sql.execution.streaming.state.IteratorH\x00R\x08iterator\x12J\n\x04keys\x18\x07 \x01(\x0b\x32\x34.org.apache.spark.sql.execution.streaming.state.KeysH\x00R\x04keys\x12P\n\x06values\x18\x08 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ValuesH\x00R\x06values\x12Y\n\tremoveKey\x18\t \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.RemoveKeyH\x00R\tremoveKey\x12M\n\x05\x63lear\x18\n \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00R\x05\x63learB\x08\n\x06method""\n\x0eSetImplicitKey\x12\x10\n\x03key\x18\x01 \x01(\x0cR\x03key"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"=\n\rRegisterTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs";\n\x0b\x44\x65leteTimer\x12,\n\x11\x65xpiryTimestampMs\x18\x01 \x01(\x03R\x11\x65xpiryTimestampMs",\n\nListTimers\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x10ValueStateUpdate\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x07\n\x05\x43lear".\n\x0cListStateGet\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"\x0e\n\x0cListStatePut"#\n\x0b\x41ppendValue\x12\x14\n\x05value\x18\x01 \x01(\x0cR\x05value"\x0c\n\nAppendList"$\n\x08GetValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"\'\n\x0b\x43ontainsKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"=\n\x0bUpdateValue\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey\x12\x14\n\x05value\x18\x02 \x01(\x0cR\x05value"*\n\x08Iterator\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"&\n\x04Keys\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"(\n\x06Values\x12\x1e\n\niteratorId\x18\x01 \x01(\tR\niteratorId"%\n\tRemoveKey\x12\x18\n\x07userKey\x18\x01 \x01(\x0cR\x07userKey"c\n\x0eSetHandleState\x12Q\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleStateR\x05state"+\n\tTTLConfig\x12\x1e\n\ndurationMs\x18\x01 \x01(\x05R\ndurationMs*`\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\x13\n\x0fTIMER_PROCESSED\x10\x03\x12\n\n\x06\x43LOSED\x10\x04\x62\x06proto3' ) _globals = globals() @@ -50,82 +50,88 @@ ) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 5997 - _globals["_HANDLESTATE"]._serialized_end = 6093 + _globals["_HANDLESTATE"]._serialized_start = 6408 + _globals["_HANDLESTATE"]._serialized_end = 6504 
_globals["_STATEREQUEST"]._serialized_start = 112 - _globals["_STATEREQUEST"]._serialized_end = 656 - _globals["_STATERESPONSE"]._serialized_start = 658 - _globals["_STATERESPONSE"]._serialized_end = 763 - _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 765 - _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 885 - _globals["_STATEFULPROCESSORCALL"]._serialized_start = 888 - _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1560 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1563 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1904 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1907 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2166 - _globals["_TIMERREQUEST"]._serialized_start = 2169 - _globals["_TIMERREQUEST"]._serialized_end = 2426 - _globals["_TIMERVALUEREQUEST"]._serialized_start = 2429 - _globals["_TIMERVALUEREQUEST"]._serialized_end = 2675 - _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2677 - _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2743 - _globals["_GETPROCESSINGTIME"]._serialized_start = 2745 - _globals["_GETPROCESSINGTIME"]._serialized_end = 2764 - _globals["_GETWATERMARK"]._serialized_start = 2766 - _globals["_GETWATERMARK"]._serialized_end = 2780 - _globals["_STATECALLCOMMAND"]._serialized_start = 2783 - _globals["_STATECALLCOMMAND"]._serialized_end = 2982 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2985 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3280 - _globals["_VALUESTATECALL"]._serialized_start = 3283 - _globals["_VALUESTATECALL"]._serialized_end = 3685 - _globals["_LISTSTATECALL"]._serialized_start = 3688 - _globals["_LISTSTATECALL"]._serialized_end = 4295 - _globals["_MAPSTATECALL"]._serialized_start = 4298 - _globals["_MAPSTATECALL"]._serialized_end = 5132 - _globals["_SETIMPLICITKEY"]._serialized_start = 5134 - _globals["_SETIMPLICITKEY"]._serialized_end = 5168 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5170 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5189 - _globals["_EXISTS"]._serialized_start = 5191 - _globals["_EXISTS"]._serialized_end = 5199 - _globals["_GET"]._serialized_start = 5201 - _globals["_GET"]._serialized_end = 5206 - _globals["_REGISTERTIMER"]._serialized_start = 5208 - _globals["_REGISTERTIMER"]._serialized_end = 5269 - _globals["_DELETETIMER"]._serialized_start = 5271 - _globals["_DELETETIMER"]._serialized_end = 5330 - _globals["_LISTTIMERS"]._serialized_start = 5332 - _globals["_LISTTIMERS"]._serialized_end = 5376 - _globals["_VALUESTATEUPDATE"]._serialized_start = 5378 - _globals["_VALUESTATEUPDATE"]._serialized_end = 5418 - _globals["_CLEAR"]._serialized_start = 5420 - _globals["_CLEAR"]._serialized_end = 5427 - _globals["_LISTSTATEGET"]._serialized_start = 5429 - _globals["_LISTSTATEGET"]._serialized_end = 5475 - _globals["_LISTSTATEPUT"]._serialized_start = 5477 - _globals["_LISTSTATEPUT"]._serialized_end = 5491 - _globals["_APPENDVALUE"]._serialized_start = 5493 - _globals["_APPENDVALUE"]._serialized_end = 5528 - _globals["_APPENDLIST"]._serialized_start = 5530 - _globals["_APPENDLIST"]._serialized_end = 5542 - _globals["_GETVALUE"]._serialized_start = 5544 - _globals["_GETVALUE"]._serialized_end = 5580 - _globals["_CONTAINSKEY"]._serialized_start = 5582 - _globals["_CONTAINSKEY"]._serialized_end = 5621 - _globals["_UPDATEVALUE"]._serialized_start = 5623 - _globals["_UPDATEVALUE"]._serialized_end = 5684 - _globals["_ITERATOR"]._serialized_start = 5686 - _globals["_ITERATOR"]._serialized_end = 5728 - 
_globals["_KEYS"]._serialized_start = 5730 - _globals["_KEYS"]._serialized_end = 5768 - _globals["_VALUES"]._serialized_start = 5770 - _globals["_VALUES"]._serialized_end = 5810 - _globals["_REMOVEKEY"]._serialized_start = 5812 - _globals["_REMOVEKEY"]._serialized_end = 5849 - _globals["_SETHANDLESTATE"]._serialized_start = 5851 - _globals["_SETHANDLESTATE"]._serialized_end = 5950 - _globals["_TTLCONFIG"]._serialized_start = 5952 - _globals["_TTLCONFIG"]._serialized_end = 5995 + _globals["_STATEREQUEST"]._serialized_end = 756 + _globals["_STATERESPONSE"]._serialized_start = 758 + _globals["_STATERESPONSE"]._serialized_end = 863 + _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 865 + _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 985 + _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_start = 987 + _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_end = 1109 + _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1112 + _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1784 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1787 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2128 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2131 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2390 + _globals["_TIMERREQUEST"]._serialized_start = 2393 + _globals["_TIMERREQUEST"]._serialized_end = 2650 + _globals["_TIMERVALUEREQUEST"]._serialized_start = 2653 + _globals["_TIMERVALUEREQUEST"]._serialized_end = 2899 + _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2901 + _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2967 + _globals["_GETPROCESSINGTIME"]._serialized_start = 2969 + _globals["_GETPROCESSINGTIME"]._serialized_end = 2988 + _globals["_GETWATERMARK"]._serialized_start = 2990 + _globals["_GETWATERMARK"]._serialized_end = 3004 + _globals["_UTILSREQUEST"]._serialized_start = 3007 + _globals["_UTILSREQUEST"]._serialized_end = 3146 + _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3148 + _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3191 + _globals["_STATECALLCOMMAND"]._serialized_start = 3194 + _globals["_STATECALLCOMMAND"]._serialized_end = 3393 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3396 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3691 + _globals["_VALUESTATECALL"]._serialized_start = 3694 + _globals["_VALUESTATECALL"]._serialized_end = 4096 + _globals["_LISTSTATECALL"]._serialized_start = 4099 + _globals["_LISTSTATECALL"]._serialized_end = 4706 + _globals["_MAPSTATECALL"]._serialized_start = 4709 + _globals["_MAPSTATECALL"]._serialized_end = 5543 + _globals["_SETIMPLICITKEY"]._serialized_start = 5545 + _globals["_SETIMPLICITKEY"]._serialized_end = 5579 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5581 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5600 + _globals["_EXISTS"]._serialized_start = 5602 + _globals["_EXISTS"]._serialized_end = 5610 + _globals["_GET"]._serialized_start = 5612 + _globals["_GET"]._serialized_end = 5617 + _globals["_REGISTERTIMER"]._serialized_start = 5619 + _globals["_REGISTERTIMER"]._serialized_end = 5680 + _globals["_DELETETIMER"]._serialized_start = 5682 + _globals["_DELETETIMER"]._serialized_end = 5741 + _globals["_LISTTIMERS"]._serialized_start = 5743 + _globals["_LISTTIMERS"]._serialized_end = 5787 + _globals["_VALUESTATEUPDATE"]._serialized_start = 5789 + _globals["_VALUESTATEUPDATE"]._serialized_end = 5829 + _globals["_CLEAR"]._serialized_start = 5831 + _globals["_CLEAR"]._serialized_end = 5838 + 
_globals["_LISTSTATEGET"]._serialized_start = 5840 + _globals["_LISTSTATEGET"]._serialized_end = 5886 + _globals["_LISTSTATEPUT"]._serialized_start = 5888 + _globals["_LISTSTATEPUT"]._serialized_end = 5902 + _globals["_APPENDVALUE"]._serialized_start = 5904 + _globals["_APPENDVALUE"]._serialized_end = 5939 + _globals["_APPENDLIST"]._serialized_start = 5941 + _globals["_APPENDLIST"]._serialized_end = 5953 + _globals["_GETVALUE"]._serialized_start = 5955 + _globals["_GETVALUE"]._serialized_end = 5991 + _globals["_CONTAINSKEY"]._serialized_start = 5993 + _globals["_CONTAINSKEY"]._serialized_end = 6032 + _globals["_UPDATEVALUE"]._serialized_start = 6034 + _globals["_UPDATEVALUE"]._serialized_end = 6095 + _globals["_ITERATOR"]._serialized_start = 6097 + _globals["_ITERATOR"]._serialized_end = 6139 + _globals["_KEYS"]._serialized_start = 6141 + _globals["_KEYS"]._serialized_end = 6179 + _globals["_VALUES"]._serialized_start = 6181 + _globals["_VALUES"]._serialized_end = 6221 + _globals["_REMOVEKEY"]._serialized_start = 6223 + _globals["_REMOVEKEY"]._serialized_end = 6260 + _globals["_SETHANDLESTATE"]._serialized_start = 6262 + _globals["_SETHANDLESTATE"]._serialized_end = 6361 + _globals["_TTLCONFIG"]._serialized_start = 6363 + _globals["_TTLCONFIG"]._serialized_end = 6406 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi index 52f66928294c..03ede5d25b2b 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi @@ -79,6 +79,7 @@ class StateRequest(google.protobuf.message.Message): STATEVARIABLEREQUEST_FIELD_NUMBER: builtins.int IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: builtins.int TIMERREQUEST_FIELD_NUMBER: builtins.int + UTILSREQUEST_FIELD_NUMBER: builtins.int version: builtins.int @property def statefulProcessorCall(self) -> global___StatefulProcessorCall: ... @@ -88,6 +89,8 @@ class StateRequest(google.protobuf.message.Message): def implicitGroupingKeyRequest(self) -> global___ImplicitGroupingKeyRequest: ... @property def timerRequest(self) -> global___TimerRequest: ... + @property + def utilsRequest(self) -> global___UtilsRequest: ... def __init__( self, *, @@ -96,6 +99,7 @@ class StateRequest(google.protobuf.message.Message): stateVariableRequest: global___StateVariableRequest | None = ..., implicitGroupingKeyRequest: global___ImplicitGroupingKeyRequest | None = ..., timerRequest: global___TimerRequest | None = ..., + utilsRequest: global___UtilsRequest | None = ..., ) -> None: ... def HasField( self, @@ -110,6 +114,8 @@ class StateRequest(google.protobuf.message.Message): b"statefulProcessorCall", "timerRequest", b"timerRequest", + "utilsRequest", + b"utilsRequest", ], ) -> builtins.bool: ... def ClearField( @@ -125,6 +131,8 @@ class StateRequest(google.protobuf.message.Message): b"statefulProcessorCall", "timerRequest", b"timerRequest", + "utilsRequest", + b"utilsRequest", "version", b"version", ], @@ -137,6 +145,7 @@ class StateRequest(google.protobuf.message.Message): "stateVariableRequest", "implicitGroupingKeyRequest", "timerRequest", + "utilsRequest", ] | None ): ... 
@@ -193,6 +202,31 @@ class StateResponseWithLongTypeVal(google.protobuf.message.Message):
 
 global___StateResponseWithLongTypeVal = StateResponseWithLongTypeVal
 
+class StateResponseWithStringTypeVal(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    STATUSCODE_FIELD_NUMBER: builtins.int
+    ERRORMESSAGE_FIELD_NUMBER: builtins.int
+    VALUE_FIELD_NUMBER: builtins.int
+    statusCode: builtins.int
+    errorMessage: builtins.str
+    value: builtins.str
+    def __init__(
+        self,
+        *,
+        statusCode: builtins.int = ...,
+        errorMessage: builtins.str = ...,
+        value: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "errorMessage", b"errorMessage", "statusCode", b"statusCode", "value", b"value"
+        ],
+    ) -> None: ...
+
+global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal
+
 class StatefulProcessorCall(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -492,6 +526,49 @@ class GetWatermark(google.protobuf.message.Message):
 
 global___GetWatermark = GetWatermark
 
+class UtilsRequest(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    PARSESTRINGSCHEMA_FIELD_NUMBER: builtins.int
+    @property
+    def parseStringSchema(self) -> global___ParseStringSchema: ...
+    def __init__(
+        self,
+        *,
+        parseStringSchema: global___ParseStringSchema | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "method", b"method", "parseStringSchema", b"parseStringSchema"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "method", b"method", "parseStringSchema", b"parseStringSchema"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["method", b"method"]
+    ) -> typing_extensions.Literal["parseStringSchema"] | None: ...
+
+global___UtilsRequest = UtilsRequest
+
+class ParseStringSchema(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SCHEMA_FIELD_NUMBER: builtins.int
+    schema: builtins.str
+    def __init__(
+        self,
+        *,
+        schema: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["schema", b"schema"]) -> None: ...
+
+global___ParseStringSchema = ParseStringSchema
+
 class StateCallCommand(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
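
Note: the regenerated stubs above expose the new messages to the Python client. A quick round trip through the generated classes, assuming a build whose StateMessage_pb2 module includes this regeneration:

    import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

    # Build a parseStringSchema request the same way _parse_string_schema does.
    request = stateMessage.StateRequest(
        utilsRequest=stateMessage.UtilsRequest(
            parseStringSchema=stateMessage.ParseStringSchema(schema="value int")
        )
    )
    payload = request.SerializeToString()

    # The server dispatches on the oneof; mirrored here for illustration.
    decoded = stateMessage.StateRequest()
    decoded.ParseFromString(payload)
    assert decoded.WhichOneof("method") == "utilsRequest"
    assert decoded.utilsRequest.parseStringSchema.schema == "value int"
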
diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py
index 9caa9304d6a8..b04bb955488a 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -45,12 +45,9 @@ class ValueState:
     .. versionadded:: 4.0.0
     """
 
-    def __init__(
-        self, value_state_client: ValueStateClient, state_name: str, schema: Union[StructType, str]
-    ) -> None:
+    def __init__(self, value_state_client: ValueStateClient, state_name: str) -> None:
         self._value_state_client = value_state_client
         self._state_name = state_name
-        self.schema = schema
 
     def exists(self) -> bool:
         """
@@ -68,7 +65,7 @@ def update(self, new_value: Tuple) -> None:
         """
         Update the value of the state.
         """
-        self._value_state_client.update(self._state_name, self.schema, new_value)
+        self._value_state_client.update(self._state_name, new_value)
 
     def clear(self) -> None:
         """
@@ -127,12 +124,9 @@ class ListState:
     .. versionadded:: 4.0.0
     """
 
-    def __init__(
-        self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
-    ) -> None:
+    def __init__(self, list_state_client: ListStateClient, state_name: str) -> None:
         self._list_state_client = list_state_client
         self._state_name = state_name
-        self.schema = schema
 
     def exists(self) -> bool:
         """
@@ -150,19 +144,19 @@ def put(self, new_state: List[Tuple]) -> None:
         """
         Update the values of the list state.
         """
-        self._list_state_client.put(self._state_name, self.schema, new_state)
+        self._list_state_client.put(self._state_name, new_state)
 
     def append_value(self, new_state: Tuple) -> None:
         """
         Append a new value to the list state.
         """
-        self._list_state_client.append_value(self._state_name, self.schema, new_state)
+        self._list_state_client.append_value(self._state_name, new_state)
 
     def append_list(self, new_state: List[Tuple]) -> None:
         """
         Append a list of new values to the list state.
         """
-        self._list_state_client.append_list(self._state_name, self.schema, new_state)
+        self._list_state_client.append_list(self._state_name, new_state)
 
     def clear(self) -> None:
         """
@@ -275,7 +269,7 @@ def getValueState(
         If ttl is not specified the state will never expire.
         """
         self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms)
-        return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema)
+        return ValueState(ValueStateClient(self.stateful_processor_api_client, schema), state_name)
 
     def getListState(
         self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None
@@ -299,7 +293,7 @@ def getListState(
         If ttl is not specified the state will never expire.
         """
         self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
-        return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
+        return ListState(ListStateClient(self.stateful_processor_api_client, schema), state_name)
 
     def getMapState(
         self,
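
Note: the handle methods keep their Union[StructType, str] signatures, so user code may pass either form; the string is now parsed exactly once, when the underlying client is constructed. A sketch of a processor written against this API, mirroring the updated tests further below (the class and column names are illustrative):

    from typing import Iterator

    import pandas as pd

    from pyspark.sql.streaming.stateful_processor import (
        StatefulProcessor,
        StatefulProcessorHandle,
    )

    class CountingProcessor(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            # A DDL string works anywhere a StructType did.
            self.count_state = handle.getValueState("count", "value int")

        def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
            count = sum(len(pdf) for pdf in rows)
            self.count_state.update((count,))
            yield pd.DataFrame({"id": [key[0]], "count": [count]})

        def close(self) -> None:
            pass
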
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 53704188081c..79bb63d81d79 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -15,16 +15,16 @@
 # limitations under the License.
 #
 from enum import Enum
+import json
 import os
 import socket
-from typing import Any, Dict, List, Union, Optional, cast, Tuple, Iterator
+from typing import Any, Dict, List, Union, Optional, Tuple, Iterator
 
 from pyspark.serializers import write_int, read_int, UTF8Deserializer
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     StructType,
     TYPE_CHECKING,
-    _parse_datatype_string,
     Row,
 )
 from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
@@ -129,7 +129,7 @@ def get_value_state(
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = self._parse_string_schema(schema)
 
         state_call_command = stateMessage.StateCallCommand()
         state_call_command.stateName = state_name
@@ -152,7 +152,7 @@ def get_list_state(
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
+            schema = self._parse_string_schema(schema)
 
         state_call_command = stateMessage.StateCallCommand()
         state_call_command.stateName = state_name
@@ -290,9 +290,9 @@ def get_map_state(
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(user_key_schema, str):
-            user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+            user_key_schema = self._parse_string_schema(user_key_schema)
         if isinstance(value_schema, str):
-            value_schema = cast(StructType, _parse_datatype_string(value_schema))
+            value_schema = self._parse_string_schema(value_schema)
 
         state_call_command = stateMessage.StateCallCommand()
         state_call_command.stateName = state_name
@@ -393,6 +393,15 @@ def _receive_proto_message_with_long_value(self) -> Tuple[int, str, int]:
         message.ParseFromString(bytes)
         return message.statusCode, message.errorMessage, message.value
 
+    def _receive_proto_message_with_string_value(self) -> Tuple[int, str, str]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        length = read_int(self.sockfile)
+        bytes = self.sockfile.read(length)
+        message = stateMessage.StateResponseWithStringTypeVal()
+        message.ParseFromString(bytes)
+        return message.statusCode, message.errorMessage, message.value
+
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
 
@@ -436,6 +445,24 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
     def _read_arrow_state(self) -> Any:
         return self.serializer.load_stream(self.sockfile)
 
+    # Parse a string schema into a StructType schema. This method will perform an API call to
+    # the JVM side to parse the schema string.
+    def _parse_string_schema(self, schema: str) -> StructType:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        parse_string_schema_call = stateMessage.ParseStringSchema(schema=schema)
+        utils_request = stateMessage.UtilsRequest(parseStringSchema=parse_string_schema_call)
+        message = stateMessage.StateRequest(utilsRequest=utils_request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message_with_string_value()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error parsing string schema: " f"{response_message[1]}")
+        else:
+            return StructType.fromJson(json.loads(response_message[2]))
+
 
 class ListTimerIterator:
     def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient):
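
Note: _parse_string_schema reuses the framing every other state-server call uses: a 4-byte big-endian length followed by the serialized proto (write_int on the Python side, DataOutputStream.writeInt on the JVM side). A self-contained illustration of that framing, with no Spark required:

    import struct
    from io import BytesIO

    def write_frame(out, payload: bytes) -> None:
        # 4-byte big-endian length prefix, then the protobuf bytes.
        out.write(struct.pack("!i", len(payload)))
        out.write(payload)

    def read_frame(src) -> bytes:
        (length,) = struct.unpack("!i", src.read(4))
        return src.read(length)

    buf = BytesIO()
    write_frame(buf, b"serialized StateRequest bytes")
    buf.seek(0)
    assert read_frame(buf) == b"serialized StateRequest bytes"
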
+ raise PySparkRuntimeError(f"Error parsing string schema: " f"{response_message[1]}") + else: + return StructType.fromJson(json.loads(response_message[2])) + class ListTimerIterator: def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient): diff --git a/python/pyspark/sql/streaming/value_state_client.py b/python/pyspark/sql/streaming/value_state_client.py index fd783af7931d..532a89cf92d2 100644 --- a/python/pyspark/sql/streaming/value_state_client.py +++ b/python/pyspark/sql/streaming/value_state_client.py @@ -14,18 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Union, cast, Tuple, Optional +from typing import Union, Tuple, Optional from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient -from pyspark.sql.types import StructType, _parse_datatype_string +from pyspark.sql.types import StructType from pyspark.errors import PySparkRuntimeError __all__ = ["ValueStateClient"] class ValueStateClient: - def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: + def __init__( + self, + stateful_processor_api_client: StatefulProcessorApiClient, + schema: Union[StructType, str], + ) -> None: self._stateful_processor_api_client = stateful_processor_api_client + if isinstance(schema, str): + self.schema = self._stateful_processor_api_client._parse_string_schema(schema) + else: + self.schema = schema def exists(self, state_name: str) -> bool: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -69,12 +77,10 @@ def get(self, state_name: str) -> Optional[Tuple]: # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}") - def update(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None: + def update(self, state_name: str, value: Tuple) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - if isinstance(schema, str): - schema = cast(StructType, _parse_datatype_string(schema)) - bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value) + bytes = self._stateful_processor_api_client._serialize_to_bytes(self.schema, value) update_call = stateMessage.ValueStateUpdate(value=bytes) value_state_call = stateMessage.ValueStateCall( stateName=state_name, valueStateUpdate=update_call diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 60f2c9348db3..5506c670c276 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1013,8 +1013,9 @@ class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase): batch_id = 0 def init(self, handle: StatefulProcessorHandle) -> None: + # Test both string type and struct type schemas + self.num_violations_state = handle.getValueState("numViolations", "value int") state_schema = StructType([StructField("value", IntegerType(), True)]) - self.num_violations_state = handle.getValueState("numViolations", state_schema) self.temp_state = handle.getValueState("tempState", state_schema) handle.deleteIfExists("tempState") @@ -1205,9 +1206,8 @@ def init(self, handle: StatefulProcessorHandle) -> None: class MapStateProcessor(StatefulProcessor): def init(self, handle: StatefulProcessorHandle): - key_schema = StructType([StructField("name", StringType(), True)]) 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 60f2c9348db3..5506c670c276 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -1013,8 +1013,9 @@ class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
     batch_id = 0
 
     def init(self, handle: StatefulProcessorHandle) -> None:
+        # Test both string type and struct type schemas
+        self.num_violations_state = handle.getValueState("numViolations", "value int")
         state_schema = StructType([StructField("value", IntegerType(), True)])
-        self.num_violations_state = handle.getValueState("numViolations", state_schema)
         self.temp_state = handle.getValueState("tempState", state_schema)
         handle.deleteIfExists("tempState")
 
@@ -1205,9 +1206,8 @@ def init(self, handle: StatefulProcessorHandle) -> None:
 
 class MapStateProcessor(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle):
-        key_schema = StructType([StructField("name", StringType(), True)])
-        value_schema = StructType([StructField("count", IntegerType(), True)])
-        self.map_state = handle.getMapState("mapState", key_schema, value_schema)
+        # Test string type schemas
+        self.map_state = handle.getMapState("mapState", "name string", "count int")
 
     def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
         count = 0
diff --git a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
index 4b0477290c8f..e69727a260a9 100644
--- a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -26,6 +26,7 @@ message StateRequest {
     StateVariableRequest stateVariableRequest = 3;
     ImplicitGroupingKeyRequest implicitGroupingKeyRequest = 4;
     TimerRequest timerRequest = 5;
+    UtilsRequest utilsRequest = 6;
   }
 }
 
@@ -41,6 +42,12 @@ message StateResponseWithLongTypeVal {
   int64 value = 3;
 }
 
+message StateResponseWithStringTypeVal {
+  int32 statusCode = 1;
+  string errorMessage = 2;
+  string value = 3;
+}
+
 message StatefulProcessorCall {
   oneof method {
     SetHandleState setHandleState = 1;
@@ -91,6 +98,16 @@ message GetProcessingTime {
 message GetWatermark {
 }
 
+message UtilsRequest {
+  oneof method {
+    ParseStringSchema parseStringSchema = 1;
+  }
+}
+
+message ParseStringSchema {
+  string schema = 1;
+}
+
 message StateCallCommand {
   string stateName = 1;
   string schema = 2;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 2957f4b38758..d03c75620df8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -33,8 +33,9 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType}
-import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateResponseWithLongTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, ValueStateCall}
+import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateResponseWithLongTypeVal, StateResponseWithStringTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, UtilsRequest, ValueStateCall}
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState}
 import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType}
 import org.apache.spark.sql.util.ArrowUtils
@@ -186,6 +187,19 @@ class TransformWithStateInPandasStateServer(
         handleStateVariableRequest(message.getStateVariableRequest)
       case StateRequest.MethodCase.TIMERREQUEST =>
         handleTimerRequest(message.getTimerRequest)
+      case StateRequest.MethodCase.UTILSREQUEST =>
+        handleUtilsRequest(message.getUtilsRequest)
+      case _ =>
+        throw new IllegalArgumentException("Invalid method call")
+    }
+  }
+
+  private[sql] def handleUtilsRequest(message: UtilsRequest): Unit = {
+    message.getMethodCase match {
+      case UtilsRequest.MethodCase.PARSESTRINGSCHEMA =>
+        val stringSchema = message.getParseStringSchema.getSchema
+        val schema = CatalystSqlParser.parseTableSchema(stringSchema)
+        sendResponseWithStringVal(0, null, schema.json)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
@@ -690,6 +704,22 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }
 
+  def sendResponseWithStringVal(
+      status: Int,
+      errorMessage: String = null,
+      stringVal: String): Unit = {
+    val responseMessageBuilder = StateResponseWithStringTypeVal.newBuilder().setStatusCode(status)
+    if (status != 0 && errorMessage != null) {
+      responseMessageBuilder.setErrorMessage(errorMessage)
+    }
+    responseMessageBuilder.setValue(stringVal)
+    val responseMessage = responseMessageBuilder.build()
+    val responseMessageBytes = responseMessage.toByteArray
+    val byteLength = responseMessageBytes.length
+    outputStream.writeInt(byteLength)
+    outputStream.write(responseMessageBytes)
+  }
+
   def sendIteratorAsArrowBatches[T](
       iter: Iterator[T],
       outputSchema: StructType,
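
Note: on success the server leaves errorMessage unset and always sets value; statusCode is the discriminator. The Python client's handling of the resulting 3-tuple can be summarized with a hypothetical helper (the function name is illustrative, not part of the patch):

    def unwrap_string_response(response) -> str:
        # response is (statusCode, errorMessage, value), as returned by
        # _receive_proto_message_with_string_value(); value is only meaningful
        # when statusCode == 0.
        status, error_message, value = response
        if status != 0:
            raise RuntimeError(f"state server error: {error_message}")
        return value
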
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index e05264825f77..c3d4541bac29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState}
 import org.apache.spark.sql.execution.streaming.state.StateMessage
-import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, ContainsKey, DeleteTimer, Exists, ExpiryTimerRequest, Get, GetProcessingTime, GetValue, GetWatermark, HandleState, Keys, ListStateCall, ListStateGet, ListStatePut, ListTimers, MapStateCall, RegisterTimer, RemoveKey, SetHandleState, StateCallCommand, StatefulProcessorCall, TimerRequest, TimerStateCallCommand, TimerValueRequest, UpdateValue, Values, ValueStateCall, ValueStateUpdate}
+import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, ContainsKey, DeleteTimer, Exists, ExpiryTimerRequest, Get, GetProcessingTime, GetValue, GetWatermark, HandleState, Keys, ListStateCall, ListStateGet, ListStatePut, ListTimers, MapStateCall, ParseStringSchema, RegisterTimer, RemoveKey, SetHandleState, StateCallCommand, StatefulProcessorCall, TimerRequest, TimerStateCallCommand, TimerValueRequest, UpdateValue, UtilsRequest, Values, ValueStateCall, ValueStateUpdate}
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -574,6 +574,16 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     verify(arrowStreamWriter).finalizeCurrentArrowBatch()
   }
 
+  test("utils request - parse string schema") {
+    val message = UtilsRequest.newBuilder().setParseStringSchema(
+      ParseStringSchema.newBuilder().setSchema(
+        "value int"
+      ).build()
+    ).build()
+    stateServer.handleUtilsRequest(message)
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
   private def getIntegerRow(value: Int): Row = {
     new GenericRowWithSchema(Array(value), stateSchema)
   }
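
Note: putting it together, an end-to-end sketch (assuming a Spark 4.0 session, a streaming DataFrame df with an id column, and the CountingProcessor sketched earlier; the parameter names follow the Python test suite touched above):

    output_schema = "id string, count int"  # DDL strings are accepted here as well

    query = (
        df.groupBy("id")
        .transformWithStateInPandas(
            statefulProcessor=CountingProcessor(),
            outputStructType=output_schema,
            outputMode="Update",
            timeMode="None",
        )
        .writeStream.format("console")
        .outputMode("update")
        .start()
    )
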