diff --git a/python/restate/context.py b/python/restate/context.py index 82a6060..7157f00 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -18,7 +18,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import typing from datetime import timedelta -from restate.serde import DefaultSerde, JsonSerde, Serde +from restate.serde import DefaultSerde, Serde T = TypeVar('T') I = TypeVar('I') @@ -92,9 +92,17 @@ class KeyValueStore(abc.ABC): @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> Awaitable[Optional[Any]]: """ Retrieves the value associated with the given name. + + Args: + name: The state name + serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. + See also 'type_hint'. + type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. """ @abc.abstractmethod @@ -105,7 +113,7 @@ def state_keys(self) -> Awaitable[List[str]]: def set(self, name: str, value: T, - serde: Serde[T] = JsonSerde()) -> None: + serde: Serde[T] = DefaultSerde()) -> None: """set the value associated with the given name.""" @abc.abstractmethod @@ -266,7 +274,9 @@ def generic_send(self, @abc.abstractmethod def awakeable(self, - serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> typing.Tuple[str, RestateDurableFuture[Any]]: """ Returns the name of the awakeable and the future to be awaited. """ @@ -275,7 +285,7 @@ def awakeable(self, def resolve_awakeable(self, name: str, value: I, - serde: Serde[I] = JsonSerde()) -> None: + serde: Serde[I] = DefaultSerde()) -> None: """ Resolves the awakeable with the given name. """ @@ -293,7 +303,9 @@ def cancel(self, invocation_id: str): """ @abc.abstractmethod - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), + type_hint: typing.Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: """ Attaches the invocation with the given id. """ @@ -323,9 +335,17 @@ def key(self) -> str: @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[Optional[Any]]: """ Retrieves the value associated with the given name. + + Args: + name: The state name + serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. + See also 'type_hint'. + type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. """ @abc.abstractmethod @@ -339,7 +359,7 @@ class DurablePromise(typing.Generic[T]): Represents a durable promise. """ - def __init__(self, name: str, serde: Serde[T] = JsonSerde()) -> None: + def __init__(self, name: str, serde: Serde[T] = DefaultSerde()) -> None: self.name = name self.serde = serde @@ -373,7 +393,7 @@ class WorkflowContext(ObjectContext): """ @abc.abstractmethod - def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]: + def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]: """ Returns a durable promise with the given name. """ @@ -384,7 +404,7 @@ class WorkflowSharedContext(ObjectSharedContext): """ @abc.abstractmethod - def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]: + def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]: """ Returns a durable promise with the given name. """ diff --git a/python/restate/serde.py b/python/restate/serde.py index ddce229..bbba15c 100644 --- a/python/restate/serde.py +++ b/python/restate/serde.py @@ -145,6 +145,17 @@ class DefaultSerde(Serde[I]): while allowing automatic serde selection based on type hints. """ + def __init__(self, type_hint: typing.Optional[typing.Type[I]] = None): + super().__init__() + self.type_hint = type_hint + + def with_maybe_type(self, type_hint: typing.Type[I] | None = None) -> "DefaultSerde[I]": + """ + Sets the type hint for the serde. + """ + self.type_hint = type_hint + return self + def deserialize(self, buf: bytes) -> typing.Optional[I]: """ Deserializes a byte array into a Python object. @@ -157,6 +168,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]: """ if not buf: return None + if is_pydantic(self.type_hint): + return self.type_hint.model_validate_json(buf) # type: ignore return json.loads(buf) def serialize(self, obj: typing.Optional[I]) -> bytes: @@ -174,11 +187,9 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: if obj is None: return bytes() - if isinstance(obj, PydanticBaseModel): - # Use the Pydantic-specific serialization + if is_pydantic(self.type_hint): return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined] - # Fallback to standard JSON serialization return json.dumps(obj).encode("utf-8") @@ -218,22 +229,3 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: return bytes() json_str = obj.model_dump_json() # type: ignore[attr-defined] return json_str.encode("utf-8") - - -def for_type(type_hint: typing.Type[T]) -> Serde[T]: - """ - Automatically selects a serde based on the type hint. - - Args: - type_hint (typing.Type[T]): The type hint to use for serde selection. - - Returns: - Serde[T]: The serde to use for the given type hint. - """ - if is_pydantic(type_hint): - return PydanticJsonSerde(type_hint) - if isinstance(type_hint, bytes): - return BytesSerde() - if isinstance(type_hint, (dict, list, int, float, str, bool)): - return JsonSerde() - return DefaultSerde() diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 90c10a0..f349eab 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -26,7 +26,7 @@ from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture from restate.exceptions import TerminalError from restate.handler import Handler, handler_from_callable, invoke_handler -from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde, for_type +from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde from restate.server_types import Receive, Send from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun @@ -130,7 +130,7 @@ class ServerDurablePromise(DurablePromise): """This class implements a durable promise API""" def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None: - super().__init__(name=name, serde=JsonSerde() if serde is None else serde) + super().__init__(name=name, serde=DefaultSerde() if serde is None else serde) self.server_context = server_context def value(self) -> RestateDurableFuture[Any]: @@ -359,15 +359,22 @@ async def inv_id_factory(): return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory) - def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: + def get(self, name: str, + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> Awaitable[Optional[T]]: handle = self.vm.sys_get_state(name) + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) return self.create_future(handle, serde) # type: ignore def state_keys(self) -> Awaitable[List[str]]: return self.create_future(self.vm.sys_get_state_keys()) - def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: + def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: """Set the value associated with the given name.""" + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type(value)) buffer = serde.serialize(value) self.vm.sys_set_state(name, bytes(buffer)) @@ -423,12 +430,11 @@ def run(self, type_hint: Optional[typing.Type[T]] = None ) -> RestateDurableFuture[T]: - if type_hint is not None: - serde = for_type(type_hint) - elif isinstance(serde, DefaultSerde): - signature = inspect.signature(action, eval_str=True) - serde = for_type(signature.return_annotation) - + if isinstance(serde, DefaultSerde): + if type_hint is None: + signature = inspect.signature(action, eval_str=True) + type_hint = signature.return_annotation + serde = serde.with_maybe_type(type_hint) handle = self.vm.sys_run(name) self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration) return self.create_future(handle, serde) # type: ignore @@ -564,16 +570,20 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = return send_handle def awakeable(self, - serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: - assert serde is not None + serde: Serde[I] = DefaultSerde(), + type_hint: Optional[typing.Type[I]] = None + ) -> typing.Tuple[str, RestateDurableFuture[Any]]: + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) name, handle = self.vm.sys_awakeable() return name, self.create_future(handle, serde) def resolve_awakeable(self, name: str, value: I, - serde: typing.Optional[Serde[I]] = JsonSerde()) -> None: - assert serde is not None + serde: Serde[I] = DefaultSerde()) -> None: + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type(value)) buf = serde.serialize(value) self.vm.sys_resolve_awakeable(name, buf) @@ -593,9 +603,12 @@ def cancel(self, invocation_id: str): raise ValueError("invocation_id cannot be None") self.vm.sys_cancel(invocation_id) - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: if invocation_id is None: raise ValueError("invocation_id cannot be None") - assert serde is not None + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) handle = self.vm.attach_invocation(invocation_id) return self.create_future(handle, serde)