From 216bb4832124b228ca8987b2323df4c5d48bbe00 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 25 Mar 2025 13:40:10 +0100 Subject: [PATCH 1/3] Add type hints in more places --- python/restate/context.py | 28 +++++++++++++------- python/restate/serde.py | 36 ++++++++++--------------- python/restate/server_context.py | 45 ++++++++++++++++++++------------ 3 files changed, 61 insertions(+), 48 deletions(-) diff --git a/python/restate/context.py b/python/restate/context.py index 82a6060..d38f3d7 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -18,7 +18,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import typing from datetime import timedelta -from restate.serde import DefaultSerde, JsonSerde, Serde +from restate.serde import DefaultSerde, Serde T = TypeVar('T') I = TypeVar('I') @@ -92,7 +92,9 @@ class KeyValueStore(abc.ABC): @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> Awaitable[Optional[Any]]: """ Retrieves the value associated with the given name. """ @@ -105,7 +107,7 @@ def state_keys(self) -> Awaitable[List[str]]: def set(self, name: str, value: T, - serde: Serde[T] = JsonSerde()) -> None: + serde: Serde[T] = DefaultSerde()) -> None: """set the value associated with the given name.""" @abc.abstractmethod @@ -266,7 +268,9 @@ def generic_send(self, @abc.abstractmethod def awakeable(self, - serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> typing.Tuple[str, RestateDurableFuture[Any]]: """ Returns the name of the awakeable and the future to be awaited. """ @@ -275,7 +279,7 @@ def awakeable(self, def resolve_awakeable(self, name: str, value: I, - serde: Serde[I] = JsonSerde()) -> None: + serde: Serde[I] = DefaultSerde()) -> None: """ Resolves the awakeable with the given name. """ @@ -293,7 +297,9 @@ def cancel(self, invocation_id: str): """ @abc.abstractmethod - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), + type_hint: typing.Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: """ Attaches the invocation with the given id. """ @@ -323,7 +329,9 @@ def key(self) -> str: @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]: + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[Optional[Any]]: """ Retrieves the value associated with the given name. """ @@ -339,7 +347,7 @@ class DurablePromise(typing.Generic[T]): Represents a durable promise. """ - def __init__(self, name: str, serde: Serde[T] = JsonSerde()) -> None: + def __init__(self, name: str, serde: Serde[T] = DefaultSerde()) -> None: self.name = name self.serde = serde @@ -373,7 +381,7 @@ class WorkflowContext(ObjectContext): """ @abc.abstractmethod - def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]: + def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]: """ Returns a durable promise with the given name. """ @@ -384,7 +392,7 @@ class WorkflowSharedContext(ObjectSharedContext): """ @abc.abstractmethod - def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]: + def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]: """ Returns a durable promise with the given name. """ diff --git a/python/restate/serde.py b/python/restate/serde.py index ddce229..bbba15c 100644 --- a/python/restate/serde.py +++ b/python/restate/serde.py @@ -145,6 +145,17 @@ class DefaultSerde(Serde[I]): while allowing automatic serde selection based on type hints. """ + def __init__(self, type_hint: typing.Optional[typing.Type[I]] = None): + super().__init__() + self.type_hint = type_hint + + def with_maybe_type(self, type_hint: typing.Type[I] | None = None) -> "DefaultSerde[I]": + """ + Sets the type hint for the serde. + """ + self.type_hint = type_hint + return self + def deserialize(self, buf: bytes) -> typing.Optional[I]: """ Deserializes a byte array into a Python object. @@ -157,6 +168,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]: """ if not buf: return None + if is_pydantic(self.type_hint): + return self.type_hint.model_validate_json(buf) # type: ignore return json.loads(buf) def serialize(self, obj: typing.Optional[I]) -> bytes: @@ -174,11 +187,9 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: if obj is None: return bytes() - if isinstance(obj, PydanticBaseModel): - # Use the Pydantic-specific serialization + if is_pydantic(self.type_hint): return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined] - # Fallback to standard JSON serialization return json.dumps(obj).encode("utf-8") @@ -218,22 +229,3 @@ def serialize(self, obj: typing.Optional[I]) -> bytes: return bytes() json_str = obj.model_dump_json() # type: ignore[attr-defined] return json_str.encode("utf-8") - - -def for_type(type_hint: typing.Type[T]) -> Serde[T]: - """ - Automatically selects a serde based on the type hint. - - Args: - type_hint (typing.Type[T]): The type hint to use for serde selection. - - Returns: - Serde[T]: The serde to use for the given type hint. - """ - if is_pydantic(type_hint): - return PydanticJsonSerde(type_hint) - if isinstance(type_hint, bytes): - return BytesSerde() - if isinstance(type_hint, (dict, list, int, float, str, bool)): - return JsonSerde() - return DefaultSerde() diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 90c10a0..f349eab 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -26,7 +26,7 @@ from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture from restate.exceptions import TerminalError from restate.handler import Handler, handler_from_callable, invoke_handler -from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde, for_type +from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde from restate.server_types import Receive, Send from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun @@ -130,7 +130,7 @@ class ServerDurablePromise(DurablePromise): """This class implements a durable promise API""" def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None: - super().__init__(name=name, serde=JsonSerde() if serde is None else serde) + super().__init__(name=name, serde=DefaultSerde() if serde is None else serde) self.server_context = server_context def value(self) -> RestateDurableFuture[Any]: @@ -359,15 +359,22 @@ async def inv_id_factory(): return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory) - def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: + def get(self, name: str, + serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> Awaitable[Optional[T]]: handle = self.vm.sys_get_state(name) + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) return self.create_future(handle, serde) # type: ignore def state_keys(self) -> Awaitable[List[str]]: return self.create_future(self.vm.sys_get_state_keys()) - def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: + def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: """Set the value associated with the given name.""" + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type(value)) buffer = serde.serialize(value) self.vm.sys_set_state(name, bytes(buffer)) @@ -423,12 +430,11 @@ def run(self, type_hint: Optional[typing.Type[T]] = None ) -> RestateDurableFuture[T]: - if type_hint is not None: - serde = for_type(type_hint) - elif isinstance(serde, DefaultSerde): - signature = inspect.signature(action, eval_str=True) - serde = for_type(signature.return_annotation) - + if isinstance(serde, DefaultSerde): + if type_hint is None: + signature = inspect.signature(action, eval_str=True) + type_hint = signature.return_annotation + serde = serde.with_maybe_type(type_hint) handle = self.vm.sys_run(name) self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration) return self.create_future(handle, serde) # type: ignore @@ -564,16 +570,20 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = return send_handle def awakeable(self, - serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: - assert serde is not None + serde: Serde[I] = DefaultSerde(), + type_hint: Optional[typing.Type[I]] = None + ) -> typing.Tuple[str, RestateDurableFuture[Any]]: + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) name, handle = self.vm.sys_awakeable() return name, self.create_future(handle, serde) def resolve_awakeable(self, name: str, value: I, - serde: typing.Optional[Serde[I]] = JsonSerde()) -> None: - assert serde is not None + serde: Serde[I] = DefaultSerde()) -> None: + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type(value)) buf = serde.serialize(value) self.vm.sys_resolve_awakeable(name, buf) @@ -593,9 +603,12 @@ def cancel(self, invocation_id: str): raise ValueError("invocation_id cannot be None") self.vm.sys_cancel(invocation_id) - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(), + type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: if invocation_id is None: raise ValueError("invocation_id cannot be None") - assert serde is not None + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) handle = self.vm.attach_invocation(invocation_id) return self.create_future(handle, serde) From f4062451564e0ff1d7fe066fdc5cd1776bc918f6 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 25 Mar 2025 13:47:31 +0000 Subject: [PATCH 2/3] Add docstr for get methods --- examples/virtual_object.py | 3 ++- python/restate/context.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/virtual_object.py b/examples/virtual_object.py index 57e6484..74e91f8 100644 --- a/examples/virtual_object.py +++ b/examples/virtual_object.py @@ -25,4 +25,5 @@ async def increment(ctx: ObjectContext, value: int) -> int: @counter.handler(kind="shared") async def count(ctx: ObjectSharedContext) -> int: - return await ctx.get("counter") or 0 + n = await ctx.get("counter", type_hint=int) or 0 + return n diff --git a/python/restate/context.py b/python/restate/context.py index d38f3d7..be74ccc 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -94,9 +94,15 @@ def get(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None - ) -> Awaitable[Optional[Any]]: + ) -> Awaitable[Optional[T]]: """ Retrieves the value associated with the given name. + + Args: + name: The state name + serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. + See also 'type_hint'. + type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. """ @abc.abstractmethod @@ -331,9 +337,15 @@ def get(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None - ) -> RestateDurableFuture[Optional[Any]]: + ) -> RestateDurableFuture[Optional[T]]: """ Retrieves the value associated with the given name. + + Args: + name: The state name + serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type. + See also 'type_hint'. + type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer. """ @abc.abstractmethod From abb27ab416d0cd32f2bb983ff92301290e39c530 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 25 Mar 2025 14:36:54 +0000 Subject: [PATCH 3/3] Revert back T -> Any --- examples/virtual_object.py | 3 +-- python/restate/context.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/virtual_object.py b/examples/virtual_object.py index 74e91f8..57e6484 100644 --- a/examples/virtual_object.py +++ b/examples/virtual_object.py @@ -25,5 +25,4 @@ async def increment(ctx: ObjectContext, value: int) -> int: @counter.handler(kind="shared") async def count(ctx: ObjectSharedContext) -> int: - n = await ctx.get("counter", type_hint=int) or 0 - return n + return await ctx.get("counter") or 0 diff --git a/python/restate/context.py b/python/restate/context.py index be74ccc..7157f00 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -94,7 +94,7 @@ def get(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None - ) -> Awaitable[Optional[T]]: + ) -> Awaitable[Optional[Any]]: """ Retrieves the value associated with the given name. @@ -337,7 +337,7 @@ def get(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None - ) -> RestateDurableFuture[Optional[T]]: + ) -> RestateDurableFuture[Optional[Any]]: """ Retrieves the value associated with the given name.