34 changes: 17 additions & 17 deletions python/pyspark/sql/streaming/list_state_client.py
@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, List, Union, cast, Tuple
+from typing import Dict, Iterator, List, Union, Tuple
 
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.sql.types import StructType, TYPE_CHECKING
 from pyspark.errors import PySparkRuntimeError
 import uuid
@@ -28,8 +28,16 @@
 
 
 class ListStateClient:
-    def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        schema: Union[StructType, str],
+    ) -> None:
         self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(schema, str):
+            self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
+        else:
+            self.schema = schema
         # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
         # and the index of the last row that was read.
         self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}

Contributor Author (review comment on the new isinstance(schema, str) branch): Make schema part of the class constructor to avoid multiple API calls for parsing.
@@ -105,12 +113,10 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
         pandas_row = pandas_df.iloc[index]
         return tuple(pandas_row)
 
-    def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None:
+    def append_value(self, state_name: str, value: Tuple) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
-        bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value)
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.schema, value)
         append_value_call = stateMessage.AppendValue(value=bytes)
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendValue=append_value_call
@@ -125,13 +131,9 @@ def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")
 
-    def append_list(
-        self, state_name: str, schema: Union[StructType, str], values: List[Tuple]
-    ) -> None:
+    def append_list(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
         append_list_call = stateMessage.AppendList()
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendList=append_list_call
@@ -141,26 +143,24 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error updating value state: " f"{response_message[1]}")
 
-    def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None:
+    def put(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if isinstance(schema, str):
-            schema = cast(StructType, _parse_datatype_string(schema))
         put_call = stateMessage.ListStatePut()
         list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
         state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
         message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
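Taken together, these changes move schema parsing out of the per-call hot path: append_value, append_list, and put now reuse self.schema instead of re-parsing a DDL string on every invocation. A minimal self-contained sketch of the parse-once pattern (the stub client below is hypothetical and only counts parse calls; the real StatefulProcessorApiClient delegates parsing to the state server):

```python
from typing import Tuple, Union

from pyspark.sql.types import StructField, StructType, StringType


class StubApiClient:
    """Hypothetical stand-in for StatefulProcessorApiClient that counts parses."""

    def __init__(self) -> None:
        self.parse_calls = 0

    def _parse_string_schema(self, schema: str) -> StructType:
        # The real method round-trips the DDL string to the state server;
        # here we just record the call and return a fixed schema.
        self.parse_calls += 1
        return StructType([StructField("value", StringType())])


class ParseOnceListState:
    """Mirrors the new ListStateClient constructor: normalize the schema once."""

    def __init__(self, api: StubApiClient, schema: Union[StructType, str]) -> None:
        self.api = api
        self.schema = api._parse_string_schema(schema) if isinstance(schema, str) else schema

    def append_value(self, value: Tuple) -> None:
        # Serialization would use self.schema here; no per-call parsing.
        assert isinstance(self.schema, StructType)


api = StubApiClient()
state = ParseOnceListState(api, "value string")
for row in [("a",), ("b",), ("c",)]:
    state.append_value(row)
assert api.parse_calls == 1  # one parse for three appends
```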
12 changes: 8 additions & 4 deletions python/pyspark/sql/streaming/map_state_client.py
@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, Union, cast, Tuple, Optional
+from typing import Dict, Iterator, Union, Tuple, Optional
 
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.sql.types import StructType, TYPE_CHECKING
 from pyspark.errors import PySparkRuntimeError
 import uuid
@@ -36,11 +36,15 @@ def __init__(
     ) -> None:
         self._stateful_processor_api_client = stateful_processor_api_client
         if isinstance(user_key_schema, str):
-            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+            self.user_key_schema = self._stateful_processor_api_client._parse_string_schema(
+                user_key_schema
+            )
         else:
             self.user_key_schema = user_key_schema
         if isinstance(value_schema, str):
-            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+            self.value_schema = self._stateful_processor_api_client._parse_string_schema(
+                value_schema
+            )
         else:
             self.value_schema = value_schema
         # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
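MapStateClient applies the same normalization twice, once for the user-key schema and once for the value schema, each parsed at most once in the constructor. A compact sketch of that branch, with a stub standing in for _parse_string_schema (the helper and stub are hypothetical):

```python
from typing import Callable, Union

from pyspark.sql.types import StructField, StructType, StringType


def normalize_schema(
    schema: Union[StructType, str], parse: Callable[[str], StructType]
) -> StructType:
    """Return a StructType, sending DDL strings through the (stubbed) parser."""
    return parse(schema) if isinstance(schema, str) else schema


parsed = StructType([StructField("k", StringType())])


def parse(ddl: str) -> StructType:
    return parsed  # pretend server-side parse of the DDL string


user_key_schema = normalize_schema("k string", parse)  # string form: parsed once
value_schema = normalize_schema(parsed, parse)         # StructType: passed through
assert user_key_schema is parsed and value_schema is parsed
```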
162 changes: 84 additions & 78 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -79,6 +79,7 @@ class StateRequest(google.protobuf.message.Message):
     STATEVARIABLEREQUEST_FIELD_NUMBER: builtins.int
     IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: builtins.int
     TIMERREQUEST_FIELD_NUMBER: builtins.int
+    UTILSREQUEST_FIELD_NUMBER: builtins.int
     version: builtins.int
     @property
     def statefulProcessorCall(self) -> global___StatefulProcessorCall: ...
@@ -88,6 +89,8 @@
     def implicitGroupingKeyRequest(self) -> global___ImplicitGroupingKeyRequest: ...
     @property
     def timerRequest(self) -> global___TimerRequest: ...
+    @property
+    def utilsRequest(self) -> global___UtilsRequest: ...
     def __init__(
         self,
         *,
@@ -96,6 +99,7 @@
         stateVariableRequest: global___StateVariableRequest | None = ...,
         implicitGroupingKeyRequest: global___ImplicitGroupingKeyRequest | None = ...,
         timerRequest: global___TimerRequest | None = ...,
+        utilsRequest: global___UtilsRequest | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -110,6 +114,8 @@
             b"statefulProcessorCall",
             "timerRequest",
             b"timerRequest",
+            "utilsRequest",
+            b"utilsRequest",
         ],
     ) -> builtins.bool: ...
     def ClearField(
@@ -125,6 +131,8 @@
             b"statefulProcessorCall",
             "timerRequest",
             b"timerRequest",
+            "utilsRequest",
+            b"utilsRequest",
             "version",
             b"version",
         ],
@@ -137,6 +145,7 @@
             "stateVariableRequest",
             "implicitGroupingKeyRequest",
             "timerRequest",
+            "utilsRequest",
         ]
         | None
     ): ...
@@ -193,6 +202,31 @@ class StateResponseWithLongTypeVal(google.protobuf.message.Message):
 
 global___StateResponseWithLongTypeVal = StateResponseWithLongTypeVal
 
+class StateResponseWithStringTypeVal(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    STATUSCODE_FIELD_NUMBER: builtins.int
+    ERRORMESSAGE_FIELD_NUMBER: builtins.int
+    VALUE_FIELD_NUMBER: builtins.int
+    statusCode: builtins.int
+    errorMessage: builtins.str
+    value: builtins.str
+    def __init__(
+        self,
+        *,
+        statusCode: builtins.int = ...,
+        errorMessage: builtins.str = ...,
+        value: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "errorMessage", b"errorMessage", "statusCode", b"statusCode", "value", b"value"
+        ],
+    ) -> None: ...
+
+global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal
+
 class StatefulProcessorCall(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -492,6 +526,49 @@ class GetWatermark(google.protobuf.message.Message):
 
 global___GetWatermark = GetWatermark
 
+class UtilsRequest(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    PARSESTRINGSCHEMA_FIELD_NUMBER: builtins.int
+    @property
+    def parseStringSchema(self) -> global___ParseStringSchema: ...
+    def __init__(
+        self,
+        *,
+        parseStringSchema: global___ParseStringSchema | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "method", b"method", "parseStringSchema", b"parseStringSchema"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "method", b"method", "parseStringSchema", b"parseStringSchema"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["method", b"method"]
+    ) -> typing_extensions.Literal["parseStringSchema"] | None: ...
+
+global___UtilsRequest = UtilsRequest
+
+class ParseStringSchema(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SCHEMA_FIELD_NUMBER: builtins.int
+    schema: builtins.str
+    def __init__(
+        self,
+        *,
+        schema: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["schema", b"schema"]) -> None: ...
+
+global___ParseStringSchema = ParseStringSchema
+
 class StateCallCommand(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
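These stubs define the wire format behind _parse_string_schema: a ParseStringSchema payload rides inside a UtilsRequest, which becomes one more oneof arm of the top-level StateRequest, and StateResponseWithStringTypeVal presumably carries the parsed schema back as a string. A sketch of building and decoding these messages with the generated classes (the local round trip simulates the server; the exact contents of the reply's value field are an assumption, not shown in this diff):

```python
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

# Request: ParseStringSchema -> UtilsRequest -> StateRequest, per the stubs above.
request = stateMessage.StateRequest(
    utilsRequest=stateMessage.UtilsRequest(
        parseStringSchema=stateMessage.ParseStringSchema(schema="user_id string, score int")
    )
)
payload = request.SerializeToString()  # bytes the worker would send to the server

# Response: simulated locally so the sketch runs without a state server. We
# assume `value` carries the parsed schema (e.g., as a JSON string).
reply_bytes = stateMessage.StateResponseWithStringTypeVal(
    statusCode=0, value='{"type":"struct","fields":[]}'
).SerializeToString()

response = stateMessage.StateResponseWithStringTypeVal()
response.ParseFromString(reply_bytes)
if response.statusCode != 0:
    # Mirrors the error-handling pattern used by the state clients.
    raise RuntimeError(response.errorMessage)
parsed_schema_string = response.value
```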
22 changes: 8 additions & 14 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -45,12 +45,9 @@ class ValueState:
     .. versionadded:: 4.0.0
     """
 
-    def __init__(
-        self, value_state_client: ValueStateClient, state_name: str, schema: Union[StructType, str]
-    ) -> None:
+    def __init__(self, value_state_client: ValueStateClient, state_name: str) -> None:
         self._value_state_client = value_state_client
         self._state_name = state_name
-        self.schema = schema
 
     def exists(self) -> bool:
         """
@@ -68,7 +65,7 @@ def update(self, new_value: Tuple) -> None:
         """
         Update the value of the state.
         """
-        self._value_state_client.update(self._state_name, self.schema, new_value)
+        self._value_state_client.update(self._state_name, new_value)
 
     def clear(self) -> None:
         """
@@ -127,12 +124,9 @@ class ListState:
     .. versionadded:: 4.0.0
     """
 
-    def __init__(
-        self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
-    ) -> None:
+    def __init__(self, list_state_client: ListStateClient, state_name: str) -> None:
         self._list_state_client = list_state_client
         self._state_name = state_name
-        self.schema = schema
 
     def exists(self) -> bool:
         """
@@ -150,19 +144,19 @@ def put(self, new_state: List[Tuple]) -> None:
         """
         Update the values of the list state.
         """
-        self._list_state_client.put(self._state_name, self.schema, new_state)
+        self._list_state_client.put(self._state_name, new_state)
 
     def append_value(self, new_state: Tuple) -> None:
         """
         Append a new value to the list state.
         """
-        self._list_state_client.append_value(self._state_name, self.schema, new_state)
+        self._list_state_client.append_value(self._state_name, new_state)
 
     def append_list(self, new_state: List[Tuple]) -> None:
         """
         Append a list of new values to the list state.
         """
-        self._list_state_client.append_list(self._state_name, self.schema, new_state)
+        self._list_state_client.append_list(self._state_name, new_state)
 
     def clear(self) -> None:
         """
@@ -275,7 +269,7 @@ def getValueState(
         If ttl is not specified the state will never expire.
         """
         self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms)
-        return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema)
+        return ValueState(ValueStateClient(self.stateful_processor_api_client, schema), state_name)
 
     def getListState(
         self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None
@@ -299,7 +293,7 @@ def getListState(
         If ttl is not specified the state will never expire.
         """
         self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
-        return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
+        return ListState(ListStateClient(self.stateful_processor_api_client, schema), state_name)
 
     def getMapState(
         self,
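From the user's side nothing changes: the schema is still supplied once to getValueState/getListState, and the returned handles are called without it. A hedged usage sketch (the processor class, its field names, and the exact shape of get()'s return value are illustrative assumptions based on the documented StatefulProcessor API, not part of this diff):

```python
from typing import Iterator

import pandas as pd

from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle


class CountProcessor(StatefulProcessor):
    """Illustrative processor keeping a per-key running count in value state."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        # The DDL string is parsed once, inside the state client constructor.
        self._count = handle.getValueState("count", "count int")

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        count = self._count.get()[0] if self._count.exists() else 0
        for batch in rows:
            count += len(batch)
        self._count.update((count,))  # note: no schema argument per call
        yield pd.DataFrame({"id": [key[0]], "count": [count]})

    def close(self) -> None:
        pass
```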