Skip to content

Use DefaultSerde in more places #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union
import typing
from datetime import timedelta
from restate.serde import DefaultSerde, JsonSerde, Serde
from restate.serde import DefaultSerde, Serde

T = TypeVar('T')
I = TypeVar('I')
Expand Down Expand Up @@ -92,9 +92,17 @@ class KeyValueStore(abc.ABC):
@abc.abstractmethod
def get(self,
name: str,
serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]:
serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
) -> Awaitable[Optional[Any]]:
"""
Retrieves the value associated with the given name.

Args:
name: The state name
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
See also 'type_hint'.
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
"""

@abc.abstractmethod
Expand All @@ -105,7 +113,7 @@ def state_keys(self) -> Awaitable[List[str]]:
def set(self,
name: str,
value: T,
serde: Serde[T] = JsonSerde()) -> None:
serde: Serde[T] = DefaultSerde()) -> None:
"""set the value associated with the given name."""

@abc.abstractmethod
Expand Down Expand Up @@ -266,7 +274,9 @@ def generic_send(self,

@abc.abstractmethod
def awakeable(self,
serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
) -> typing.Tuple[str, RestateDurableFuture[Any]]:
"""
Returns the name of the awakeable and the future to be awaited.
"""
Expand All @@ -275,7 +285,7 @@ def awakeable(self,
def resolve_awakeable(self,
name: str,
value: I,
serde: Serde[I] = JsonSerde()) -> None:
serde: Serde[I] = DefaultSerde()) -> None:
"""
Resolves the awakeable with the given name.
"""
Expand All @@ -293,7 +303,9 @@ def cancel(self, invocation_id: str):
"""

@abc.abstractmethod
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(),
type_hint: typing.Optional[typing.Type[T]] = None
) -> RestateDurableFuture[T]:
"""
Attaches the invocation with the given id.
"""
Expand Down Expand Up @@ -323,9 +335,17 @@ def key(self) -> str:
@abc.abstractmethod
def get(self,
name: str,
serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]:
serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
) -> RestateDurableFuture[Optional[Any]]:
"""
Retrieves the value associated with the given name.

Args:
name: The state name
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
See also 'type_hint'.
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
"""

@abc.abstractmethod
Expand All @@ -339,7 +359,7 @@ class DurablePromise(typing.Generic[T]):
Represents a durable promise.
"""

def __init__(self, name: str, serde: Serde[T] = JsonSerde()) -> None:
def __init__(self, name: str, serde: Serde[T] = DefaultSerde()) -> None:
self.name = name
self.serde = serde

Expand Down Expand Up @@ -373,7 +393,7 @@ class WorkflowContext(ObjectContext):
"""

@abc.abstractmethod
def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]:
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
"""
Returns a durable promise with the given name.
"""
Expand All @@ -384,7 +404,7 @@ class WorkflowSharedContext(ObjectSharedContext):
"""

@abc.abstractmethod
def promise(self, name: str, serde: Serde[T] = JsonSerde()) -> DurablePromise[Any]:
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
"""
Returns a durable promise with the given name.
"""
36 changes: 14 additions & 22 deletions python/restate/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,17 @@ class DefaultSerde(Serde[I]):
while allowing automatic serde selection based on type hints.
"""

def __init__(self, type_hint: typing.Optional[typing.Type[I]] = None):
super().__init__()
self.type_hint = type_hint

def with_maybe_type(self, type_hint: typing.Type[I] | None = None) -> "DefaultSerde[I]":
"""
Sets the type hint for the serde.
"""
self.type_hint = type_hint
return self

def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
Deserializes a byte array into a Python object.
Expand All @@ -157,6 +168,8 @@ def deserialize(self, buf: bytes) -> typing.Optional[I]:
"""
if not buf:
return None
if is_pydantic(self.type_hint):
return self.type_hint.model_validate_json(buf) # type: ignore
return json.loads(buf)

def serialize(self, obj: typing.Optional[I]) -> bytes:
Expand All @@ -174,11 +187,9 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
if obj is None:
return bytes()

if isinstance(obj, PydanticBaseModel):
# Use the Pydantic-specific serialization
if is_pydantic(self.type_hint):
return obj.model_dump_json().encode("utf-8") # type: ignore[attr-defined]

# Fallback to standard JSON serialization
return json.dumps(obj).encode("utf-8")


Expand Down Expand Up @@ -218,22 +229,3 @@ def serialize(self, obj: typing.Optional[I]) -> bytes:
return bytes()
json_str = obj.model_dump_json() # type: ignore[attr-defined]
return json_str.encode("utf-8")


def for_type(type_hint: typing.Type[T]) -> Serde[T]:
"""
Automatically selects a serde based on the type hint.

Args:
type_hint (typing.Type[T]): The type hint to use for serde selection.

Returns:
Serde[T]: The serde to use for the given type hint.
"""
if is_pydantic(type_hint):
return PydanticJsonSerde(type_hint)
if isinstance(type_hint, bytes):
return BytesSerde()
if isinstance(type_hint, (dict, list, int, float, str, bool)):
return JsonSerde()
return DefaultSerde()
45 changes: 29 additions & 16 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, SendHandle, RestateDurableSleepFuture
from restate.exceptions import TerminalError
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde, for_type
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
from restate.server_types import Receive, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
Expand Down Expand Up @@ -130,7 +130,7 @@ class ServerDurablePromise(DurablePromise):
"""This class implements a durable promise API"""

def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None:
super().__init__(name=name, serde=JsonSerde() if serde is None else serde)
super().__init__(name=name, serde=DefaultSerde() if serde is None else serde)
self.server_context = server_context

def value(self) -> RestateDurableFuture[Any]:
Expand Down Expand Up @@ -359,15 +359,22 @@ async def inv_id_factory():

return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)

def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
def get(self, name: str,
serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
) -> Awaitable[Optional[T]]:
handle = self.vm.sys_get_state(name)
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
return self.create_future(handle, serde) # type: ignore

def state_keys(self) -> Awaitable[List[str]]:
return self.create_future(self.vm.sys_get_state_keys())

def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
def set(self, name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None:
"""Set the value associated with the given name."""
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type(value))
buffer = serde.serialize(value)
self.vm.sys_set_state(name, bytes(buffer))

Expand Down Expand Up @@ -423,12 +430,11 @@ def run(self,
type_hint: Optional[typing.Type[T]] = None
) -> RestateDurableFuture[T]:

if type_hint is not None:
serde = for_type(type_hint)
elif isinstance(serde, DefaultSerde):
signature = inspect.signature(action, eval_str=True)
serde = for_type(signature.return_annotation)

if isinstance(serde, DefaultSerde):
if type_hint is None:
signature = inspect.signature(action, eval_str=True)
type_hint = signature.return_annotation
serde = serde.with_maybe_type(type_hint)
handle = self.vm.sys_run(name)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)
return self.create_future(handle, serde) # type: ignore
Expand Down Expand Up @@ -564,16 +570,20 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None =
return send_handle

def awakeable(self,
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
assert serde is not None
serde: Serde[I] = DefaultSerde(),
type_hint: Optional[typing.Type[I]] = None
) -> typing.Tuple[str, RestateDurableFuture[Any]]:
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
name, handle = self.vm.sys_awakeable()
return name, self.create_future(handle, serde)

def resolve_awakeable(self,
name: str,
value: I,
serde: typing.Optional[Serde[I]] = JsonSerde()) -> None:
assert serde is not None
serde: Serde[I] = DefaultSerde()) -> None:
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type(value))
buf = serde.serialize(value)
self.vm.sys_resolve_awakeable(name, buf)

Expand All @@ -593,9 +603,12 @@ def cancel(self, invocation_id: str):
raise ValueError("invocation_id cannot be None")
self.vm.sys_cancel(invocation_id)

def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde(),
type_hint: Optional[typing.Type[T]] = None
) -> RestateDurableFuture[T]:
if invocation_id is None:
raise ValueError("invocation_id cannot be None")
assert serde is not None
if isinstance(serde, DefaultSerde):
serde = serde.with_maybe_type(type_hint)
handle = self.vm.attach_invocation(invocation_id)
return self.create_future(handle, serde)