From 24b8be6f1f86568342ff2d80c93af5407b01b9da Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 18 Mar 2025 15:23:49 +0000 Subject: [PATCH 1/4] Use a custom future --- python/restate/context.py | 35 ++++++---- python/restate/server_context.py | 115 ++++++++++++++++--------------- 2 files changed, 82 insertions(+), 68 deletions(-) diff --git a/python/restate/context.py b/python/restate/context.py index 6e7624a..9fafde0 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -27,6 +27,17 @@ RunAction = Union[Callable[[], T], Callable[[], Awaitable[T]]] +class RestateDurableFuture(typing.Generic[T], Awaitable[T]): + """ + Represents a durable future. + """ + + @abc.abstractmethod + def __await__(self): + pass + + + @dataclass class Request: """ @@ -57,13 +68,13 @@ class KeyValueStore(abc.ABC): @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]: + serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]: """ Retrieves the value associated with the given name. """ @abc.abstractmethod - def state_keys(self) -> Awaitable[List[str]]: + def state_keys(self) -> RestateDurableFuture[List[str]]: """Returns the list of keys in the store.""" @abc.abstractmethod @@ -110,7 +121,7 @@ def run(self, action: RunAction[T], serde: Serde[T] = JsonSerde(), max_attempts: typing.Optional[int] = None, - max_retry_duration: typing.Optional[timedelta] = None) -> Awaitable[T]: + max_retry_duration: typing.Optional[timedelta] = None) -> RestateDurableFuture[T]: """ Runs the given action with the given name. @@ -126,7 +137,7 @@ def run(self, """ @abc.abstractmethod - def sleep(self, delta: timedelta) -> Awaitable[None]: + def sleep(self, delta: timedelta) -> RestateDurableFuture[None]: """ Suspends the current invocation for the given duration """ @@ -135,7 +146,7 @@ def sleep(self, delta: timedelta) -> Awaitable[None]: def service_call(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, - idempotency_key: str | None = None) -> Awaitable[O]: + idempotency_key: str | None = None) -> RestateDurableFuture[O]: """ Invokes the given service with the given argument. """ @@ -158,7 +169,7 @@ def object_call(self, key: str, arg: I, idempotency_key: str | None = None, - ) -> Awaitable[O]: + ) -> RestateDurableFuture[O]: """ Invokes the given object with the given argument. """ @@ -181,7 +192,7 @@ def workflow_call(self, key: str, arg: I, idempotency_key: str | None = None, - ) -> Awaitable[O]: + ) -> RestateDurableFuture[O]: """ Invokes the given workflow with the given argument. """ @@ -205,7 +216,7 @@ def generic_call(self, handler: str, arg: bytes, key: Optional[str] = None, - idempotency_key: str | None = None) -> Awaitable[bytes]: + idempotency_key: str | None = None) -> RestateDurableFuture[bytes]: """ Invokes the given generic service/handler with the given argument. """ @@ -225,7 +236,7 @@ def generic_send(self, @abc.abstractmethod def awakeable(self, - serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]: + serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: """ Returns the name of the awakeable and the future to be awaited. """ @@ -252,7 +263,7 @@ def cancel(self, invocation_id: str): """ @abc.abstractmethod - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> T: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: """ Attaches the invocation with the given id. """ @@ -282,13 +293,13 @@ def key(self) -> str: @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]: + serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]: """ Retrieves the value associated with the given name. """ @abc.abstractmethod - def state_keys(self) -> Awaitable[List[str]]: + def state_keys(self) -> RestateDurableFuture[List[str]]: """ Returns the list of keys in the store. """ diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 94d059c..fe76e78 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -18,7 +18,7 @@ import typing import traceback -from restate.context import DurablePromise, ObjectContext, Request, SendHandle +from restate.context import DurablePromise, ObjectContext, Request, RestateDurableFuture, SendHandle from restate.exceptions import TerminalError from restate.handler import Handler, handler_from_callable, invoke_handler from restate.serde import BytesSerde, JsonSerde, Serde @@ -39,6 +39,28 @@ # disable too few public methods # pylint: disable=R0903 + +class ServerDurableFuture(RestateDurableFuture[T]): + """This class implements a durable future API""" + value: T | None = None + metadata: Dict[str, Any] | None = None + + def __init__(self, handle: int, factory) -> None: + super().__init__() + self.factory = factory + self.handle = handle + + def with_metadata(self, **metadata) -> 'ServerDurableFuture': + """Add metadata to the future.""" + self.metadata = metadata + return self + + def __await__(self): + print("..........Awaiting............", flush=True) + task = asyncio.create_task(self.factory()) + return task.__await__() + + class ServerSendHandle(SendHandle): """This class implements the send API""" _invocation_id: typing.Optional[str] @@ -57,6 +79,8 @@ async def invocation_id(self) -> str: self._invocation_id = res return res + + async def async_value(n: Callable[[], T]) -> T: """convert a simple value to a coroutine.""" return n() @@ -210,6 +234,7 @@ def must_take_notification(self, handle): raise TerminalError(res.message, res.code) return res + async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None: """Create a coroutine to poll the handle.""" await self.take_and_send_output() @@ -240,20 +265,25 @@ async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None: await self.take_and_send_output() - def get(self, name: str, serde: Serde[T] = JsonSerde()) -> typing.Awaitable[Optional[Any]]: - coro = self.create_poll_or_cancel_coroutine(self.vm.sys_get_state(name)) + def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]: + """Create a durable future.""" - async def await_point(): - """Wait for this handle to be resolved.""" - res = await coro - if res is None: - return None + async def transform(): + res = await self.create_poll_or_cancel_coroutine(handle) + if res is None or serde is None: + return res return serde.deserialize(res) - return await_point() # do not await here, the caller will do it. + return ServerDurableFuture(handle, lambda : transform()) - def state_keys(self) -> Awaitable[List[str]]: - return self.create_poll_or_cancel_coroutine(self.vm.sys_get_state_keys()) # type: ignore + + + def get(self, name: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[T]]: + handle = self.vm.sys_get_state(name) + return self.create_df(handle, serde) # type: ignore + + def state_keys(self) -> RestateDurableFuture[List[str]]: + return self.create_df(self.vm.sys_get_state_keys()) # type: ignore def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: """Set the value associated with the given name.""" @@ -309,29 +339,22 @@ def run(self, action: Callable[[], T] | Callable[[], Awaitable[T]], serde: Optional[Serde[T]] = JsonSerde(), max_attempts: Optional[int] = None, - max_retry_duration: Optional[timedelta] = None) -> Awaitable[T]: + max_retry_duration: Optional[timedelta] = None) -> RestateDurableFuture[T]: assert serde is not None handle = self.vm.sys_run(name) # Register closure to run + # TODO: use thunk to avoid coro leak warning. self.run_coros_to_execute[handle] = self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration) # Prepare response coroutine - coro = self.create_poll_or_cancel_coroutine(handle) - async def await_point(): - """Wait for this handle to be resolved.""" - res = await coro - if res is None: - return None - return serde.deserialize(res) - - return await_point() # do not await here, the caller will do it. + return self.create_df(handle, serde) # type: ignore - def sleep(self, delta: timedelta) -> Awaitable[None]: + def sleep(self, delta: timedelta) -> RestateDurableFuture[None]: # convert timedelta to milliseconds millis = int(delta.total_seconds() * 1000) - return self.create_poll_or_cancel_coroutine(self.vm.sys_sleep(millis)) # type: ignore + return self.create_df(self.vm.sys_sleep(millis)) # type: ignore def do_call(self, tpe: Callable[[Any, I], Awaitable[O]], @@ -341,7 +364,7 @@ def do_call(self, send: bool = False, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None - ) -> Awaitable[O] | SendHandle: + ) -> RestateDurableFuture[O] | SendHandle: """Make an RPC call to the given handler""" target_handler = handler_from_callable(tpe) service=target_handler.service_tag.name @@ -362,7 +385,7 @@ def do_raw_call(self, send: bool = False, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None - ) -> Awaitable[O] | SendHandle: + ) -> RestateDurableFuture[O] | SendHandle: """Make an RPC call to the given handler""" parameter = input_serde.serialize(input_param) if send_delay: @@ -380,19 +403,16 @@ def do_raw_call(self, idempotency_key=idempotency_key, headers=headers) - async def await_point(s: ServerInvocationContext, h, o: Serde[O]): - """Wait for this handle to be resolved, and deserialize the response.""" - res = await s.create_poll_or_cancel_coroutine(h) - return o.deserialize(res) # type: ignore - - return await_point(self, handle.result_handle, output_serde) + # TODO: specialize this future for calls! + return self.create_df(handle=handle.result_handle, serde=output_serde).with_metadata(invocation_id=handle.invocation_id_handle) + def service_call(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None - ) -> Awaitable[O]: + ) -> RestateDurableFuture[O]: coro = self.do_call(tpe, arg, idempotency_key=idempotency_key, headers=headers) assert not isinstance(coro, SendHandle) return coro @@ -408,7 +428,7 @@ def object_call(self, arg: I, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None - ) -> Awaitable[O]: + ) -> RestateDurableFuture[O]: coro = self.do_call(tpe, arg, key, idempotency_key=idempotency_key, headers=headers) assert not isinstance(coro, SendHandle) return coro @@ -424,7 +444,7 @@ def workflow_call(self, arg: I, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None - ) -> Awaitable[O]: + ) -> RestateDurableFuture[O]: return self.object_call(tpe, key, arg, idempotency_key=idempotency_key, headers=headers) def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle: @@ -432,7 +452,7 @@ def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, assert isinstance(send, SendHandle) return send - def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> Awaitable[bytes]: + def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> RestateDurableFuture[bytes]: serde = BytesSerde() call_handle = self.do_raw_call(service=service, handler=handler, @@ -461,19 +481,10 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None = return send_handle def awakeable(self, - serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]: + serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: assert serde is not None name, handle = self.vm.sys_awakeable() - coro = self.create_poll_or_cancel_coroutine(handle) - - async def await_point(): - """Wait for this handle to be resolved.""" - res = await coro - assert res is not None - return serde.deserialize(res) - - - return name, await_point() + return name, self.create_df(handle, serde) def resolve_awakeable(self, name: str, @@ -499,17 +510,9 @@ def cancel(self, invocation_id: str): raise ValueError("invocation_id cannot be None") self.vm.sys_cancel(invocation_id) - def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> T: + def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]: if invocation_id is None: raise ValueError("invocation_id cannot be None") assert serde is not None handle = self.vm.attach_invocation(invocation_id) - coro = self.create_poll_or_cancel_coroutine(handle) - - async def await_point(): - """Wait for this handle to be resolved.""" - res = await coro - assert res is not None - return serde.deserialize(res) - - return await_point() + return self.create_df(handle, serde) \ No newline at end of file From ee8f015bbe1be4a0078476f93a6131206b64960d Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 18 Mar 2025 16:30:41 +0100 Subject: [PATCH 2/4] lint --- python/restate/server_context.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/restate/server_context.py b/python/restate/server_context.py index fe76e78..dea8ca8 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -39,6 +39,8 @@ # disable too few public methods # pylint: disable=R0903 +# pylint: disable=W0511 + class ServerDurableFuture(RestateDurableFuture[T]): """This class implements a durable future API""" @@ -274,7 +276,7 @@ async def transform(): return res return serde.deserialize(res) - return ServerDurableFuture(handle, lambda : transform()) + return ServerDurableFuture(handle, transform) @@ -406,7 +408,6 @@ def do_raw_call(self, # TODO: specialize this future for calls! return self.create_df(handle=handle.result_handle, serde=output_serde).with_metadata(invocation_id=handle.invocation_id_handle) - def service_call(self, tpe: Callable[[Any, I], Awaitable[O]], arg: I, @@ -515,4 +516,4 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) - raise ValueError("invocation_id cannot be None") assert serde is not None handle = self.vm.attach_invocation(invocation_id) - return self.create_df(handle, serde) \ No newline at end of file + return self.create_df(handle, serde) From 9ecac83e5dc30c61edc01d519b95ab3b30e869a1 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 18 Mar 2025 16:36:08 +0100 Subject: [PATCH 3/4] Use Awaitable for few ctx methods --- python/restate/context.py | 4 ++-- python/restate/server_context.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/restate/context.py b/python/restate/context.py index 9fafde0..2e67d42 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -68,13 +68,13 @@ class KeyValueStore(abc.ABC): @abc.abstractmethod def get(self, name: str, - serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]: + serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]: """ Retrieves the value associated with the given name. """ @abc.abstractmethod - def state_keys(self) -> RestateDurableFuture[List[str]]: + def state_keys(self) -> Awaitable[List[str]]: """Returns the list of keys in the store.""" @abc.abstractmethod diff --git a/python/restate/server_context.py b/python/restate/server_context.py index dea8ca8..c2219d6 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -280,11 +280,11 @@ async def transform(): - def get(self, name: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[T]]: + def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: handle = self.vm.sys_get_state(name) return self.create_df(handle, serde) # type: ignore - def state_keys(self) -> RestateDurableFuture[List[str]]: + def state_keys(self) -> Awaitable[List[str]]: return self.create_df(self.vm.sys_get_state_keys()) # type: ignore def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: From c99fcfc242f8788b96389e44d60007c6146660f1 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 18 Mar 2025 15:39:34 +0000 Subject: [PATCH 4/4] Lint --- python/restate/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/restate/context.py b/python/restate/context.py index 2e67d42..01ae648 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -27,6 +27,7 @@ RunAction = Union[Callable[[], T], Callable[[], Awaitable[T]]] +# pylint: disable=R0903 class RestateDurableFuture(typing.Generic[T], Awaitable[T]): """ Represents a durable future.