diff --git a/python/restate/__init__.py b/python/restate/__init__.py index 69b014f..c23b62b 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -19,7 +19,9 @@ # types from .context import Context, ObjectContext, ObjectSharedContext from .context import WorkflowContext, WorkflowSharedContext -from .context import DurablePromise +from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle +from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED +from .exceptions import TerminalError from .endpoint import app @@ -43,6 +45,15 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore "WorkflowContext", "WorkflowSharedContext", "DurablePromise", + "RestateDurableFuture", + "RestateDurableCallFuture", + "SendHandle", + "TerminalError", "app", "test_harness", + "wait", + "gather", + "as_completed", + "ALL_COMPLETED", + "FIRST_COMPLETED", ] diff --git a/python/restate/combinators.py b/python/restate/combinators.py new file mode 100644 index 0000000..b56a6bb --- /dev/null +++ b/python/restate/combinators.py @@ -0,0 +1,119 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +# pylint: disable=R0913,C0301,R0917 +# pylint: disable=line-too-long +"""combines multiple futures into a single future""" + +from typing import Any, List, Tuple +from restate.exceptions import TerminalError +from restate.context import RestateDurableFuture +from restate.server_context import ServerDurableFuture, ServerInvocationContext + +FIRST_COMPLETED = 1 +ALL_COMPLETED = 2 + +async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: + """ + Blocks until at least one of the futures/all of the futures are completed. + + Returns a tuple of two lists: the first list contains the futures that are completed, + the second list contains the futures that are not completed. + + The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED. + Using FIRST_COMPLETED will return as soon as one of the futures is completed. + Using ALL_COMPLETED will return only when all futures are completed. + + examples: + + completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED) + for completed_future in completed: + # do something with the completed future + print(await completed_future) # prints the result of the future + + or + + completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED) + assert waiting == [] + + + """ + assert mode in (FIRST_COMPLETED, ALL_COMPLETED) + + remaining = list(futures) + while remaining: + completed, waiting = await wait_completed(remaining) + if mode == FIRST_COMPLETED: + return completed, waiting + remaining = waiting + + assert mode == ALL_COMPLETED + return list(futures), [] + +async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]: + """ + Blocks until all futures are completed. + + Returns a list of all futures. + """ + completed, _ = await wait(*futures, mode=ALL_COMPLETED) + return completed + +async def as_completed(*futures: RestateDurableFuture[Any]): + """ + Returns an iterator that yields the futures as they are completed. + + example: + + async for future in as_completed(f1, f2, f3): + # do something with the completed future + print(await future) # prints the result of the future + + """ + remaining = list(futures) + while remaining: + completed, waiting = await wait_completed(remaining) + for f in completed: + yield f + remaining = waiting + +async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: + """ + Blocks until at least one of the futures is completed. + + Returns a tuple of two lists: the first list contains the futures that are completed, + the second list contains the futures that are not completed. + """ + if not futures: + return [], [] + handles: List[int] = [] + context: ServerInvocationContext | None = None + for f in futures: + if not isinstance(f, ServerDurableFuture): + raise TerminalError("All futures must SDK created futures.") + if context is None: + context = f.context + elif context is not f.context: + raise TerminalError("All futures must be created by the same SDK context.") + if f.is_completed(): + return [f], [] + handles.append(f.source_notification_handle) + + assert context is not None + await context.create_poll_or_cancel_coroutine(handles) + completed = [] + uncompleted = [] + for index, handle in enumerate(handles): + future = futures[index] + if context.vm.is_completed(handle): + completed.append(future) + else: + uncompleted.append(future) + return completed, uncompleted diff --git a/python/restate/context.py b/python/restate/context.py index 5c21cc4..5588874 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -33,6 +33,12 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]): Represents a durable future. """ + @abc.abstractmethod + def is_completed(self) -> bool: + """ + Returns True if the future is completed, False otherwise. + """ + @abc.abstractmethod def __await__(self): pass diff --git a/python/restate/server_context.py b/python/restate/server_context.py index b2fcaef..3b99f51 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -36,42 +36,73 @@ O = TypeVar('O') - - - class ServerDurableFuture(RestateDurableFuture[T]): """This class implements a durable future API""" value: T | None = None + error: TerminalError | None = None + state: typing.Literal["pending", "fulfilled", "rejected"] = "pending" - def __init__(self, handle: int, factory) -> None: + def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None: super().__init__() - self.factory = factory - self.handle = handle + self.context = context + self.source_notification_handle = handle + self.awaitable_factory = awaitable_factory + self.state = "pending" + + + def is_completed(self): + match self.state: + case "pending": + return self.context.vm.is_completed(self.source_notification_handle) + case "fulfilled": + return True + case "rejected": + return True + def __await__(self): - task = asyncio.create_task(self.factory()) - return task.__await__() + async def await_point(): + match self.state: + case "pending": + try: + self.value = await self.awaitable_factory() + self.state = "fulfilled" + return self.value + except TerminalError as t: + self.error = t + self.state = "rejected" + raise t + case "fulfilled": + return self.value + case "rejected": + assert self.error is not None + raise self.error + + + return await_point().__await__() class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]): """This class implements a durable future but for calls""" _invocation_id: typing.Optional[str] = None - def __init__(self, result_handle: int, + def __init__(self, + context: "ServerInvocationContext", + result_handle: int, result_factory, invocation_id_handle: int, invocation_id_factory) -> None: - super().__init__(result_handle, result_factory) + super().__init__(context, result_handle, result_factory) self.invocation_id_handle = invocation_id_handle self.invocation_id_factory = invocation_id_factory + async def invocation_id(self) -> str: """Get the invocation id.""" if self._invocation_id is None: self._invocation_id = await self.invocation_id_factory() return self._invocation_id - class ServerSendHandle(SendHandle): """This class implements the send API""" _invocation_id: typing.Optional[str] @@ -90,29 +121,29 @@ async def invocation_id(self) -> str: self._invocation_id = res return res - - async def async_value(n: Callable[[], T]) -> T: """convert a simple value to a coroutine.""" return n() - class ServerDurablePromise(DurablePromise): """This class implements a durable promise API""" - def __init__(self, server_context, name, serde) -> None: + def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None: super().__init__(name=name, serde=JsonSerde() if serde is None else serde) self.server_context = server_context def value(self) -> Awaitable[Any]: vm: VMWrapper = self.server_context.vm handle = vm.sys_get_promise(self.name) - coro = self.server_context.create_poll_or_cancel_coroutine(handle) + coro = self.server_context.create_poll_or_cancel_coroutine([handle]) serde = self.serde assert serde is not None async def await_point(): - res = await coro + await coro + res = self.server_context.must_take_notification(handle) + if res is None: + return None return serde.deserialize(res) return await_point() @@ -122,23 +153,33 @@ def resolve(self, value: Any) -> Awaitable[None]: assert self.serde is not None value_buffer = self.serde.serialize(value) handle = vm.sys_complete_promise_success(self.name, value_buffer) - return self.server_context.create_poll_or_cancel_coroutine(handle) + + async def await_point(): + await self.server_context.create_poll_or_cancel_coroutine([handle]) + self.server_context.must_take_notification(handle) + + return await_point() def reject(self, message: str, code: int = 500) -> Awaitable[None]: vm: VMWrapper = self.server_context.vm py_failure = Failure(code=code, message=message) handle = vm.sys_complete_promise_failure(self.name, py_failure) - return self.server_context.create_poll_or_cancel_coroutine(handle) + + async def await_point(): + await self.server_context.create_poll_or_cancel_coroutine([handle]) + self.server_context.must_take_notification(handle) + + return await_point() def peek(self) -> Awaitable[Any | None]: vm: VMWrapper = self.server_context.vm handle = vm.sys_peek_promise(self.name) - coro = self.server_context.create_poll_or_cancel_coroutine(handle) serde = self.serde assert serde is not None async def await_point(): - res = await coro + await self.server_context.create_poll_or_cancel_coroutine([handle]) + res = self.server_context.must_take_notification(handle) if res is None: return None return serde.deserialize(res) @@ -246,24 +287,17 @@ def must_take_notification(self, handle): return res - async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None: + async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None: """Create a coroutine to poll the handle.""" await self.take_and_send_output() while True: - if self.vm.is_completed(handle): - # Handle is completed - return self.must_take_notification(handle) - - # Nothing ready yet, let's try to make some progress - do_progress_response = self.vm.do_progress([handle]) + do_progress_response = self.vm.do_progress(handles) if isinstance(do_progress_response, DoProgressAnyCompleted): - # One of the handles completed, we can continue - continue + # One of the handles completed + return if isinstance(do_progress_response, DoProgressCancelSignalReceived): - # Raise cancel signal raise TerminalError("cancelled", 409) if isinstance(do_progress_response, DoProgressReadFromInput): - # We need to read from input chunk = await self.receive() if chunk.get('body', None) is not None: assert isinstance(chunk['body'], bytes) @@ -280,12 +314,13 @@ def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurable """Create a durable future.""" async def transform(): - res = await self.create_poll_or_cancel_coroutine(handle) + await self.create_poll_or_cancel_coroutine([handle]) + res = self.must_take_notification(handle) if res is None or serde is None: return res return serde.deserialize(res) - return ServerDurableFuture(handle, transform) + return ServerDurableFuture(self, handle, transform) @@ -293,15 +328,17 @@ def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] """Create a durable future.""" async def transform(): - res = await self.create_poll_or_cancel_coroutine(handle) + await self.create_poll_or_cancel_coroutine([handle]) + res = self.must_take_notification(handle) if res is None or serde is None: return res return serde.deserialize(res) - def inv_id_factory(): - return self.create_poll_or_cancel_coroutine(invocation_id_handle) + async def inv_id_factory(): + await self.create_poll_or_cancel_coroutine([invocation_id_handle]) + return self.must_take_notification(invocation_id_handle) - return ServerCallDurableFuture(handle, transform, invocation_id_handle, inv_id_factory) + return ServerCallDurableFuture(self, handle, transform, invocation_id_handle, inv_id_factory) def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: