diff --git a/python/restate/__init__.py b/python/restate/__init__.py index c23b62b..045c33c 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -20,8 +20,8 @@ from .context import Context, ObjectContext, ObjectSharedContext from .context import WorkflowContext, WorkflowSharedContext from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle -from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED from .exceptions import TerminalError +from .asyncio import as_completed, gather, wait_completed from .endpoint import app @@ -51,9 +51,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore "TerminalError", "app", "test_harness", - "wait", "gather", "as_completed", - "ALL_COMPLETED", - "FIRST_COMPLETED", + "wait_completed", ] diff --git a/python/restate/combinators.py b/python/restate/asyncio.py similarity index 61% rename from python/restate/combinators.py rename to python/restate/asyncio.py index a0673e8..0fc516e 100644 --- a/python/restate/combinators.py +++ b/python/restate/asyncio.py @@ -17,54 +17,15 @@ from restate.context import RestateDurableFuture from restate.server_context import ServerDurableFuture, ServerInvocationContext -FIRST_COMPLETED = 1 -ALL_COMPLETED = 2 - -async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: - """ - Blocks until at least one of the futures/all of the futures are completed. - - Returns a tuple of two lists: the first list contains the futures that are completed, - the second list contains the futures that are not completed. - - The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED. - Using FIRST_COMPLETED will return as soon as one of the futures is completed. - Using ALL_COMPLETED will return only when all futures are completed. - - examples: - - completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED) - for completed_future in completed: - # do something with the completed future - print(await completed_future) # prints the result of the future - - or - - completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED) - assert waiting == [] - - - """ - assert mode in (FIRST_COMPLETED, ALL_COMPLETED) - - remaining = list(futures) - while remaining: - completed, waiting = await wait_completed(remaining) - if mode == FIRST_COMPLETED: - return completed, waiting - remaining = waiting - - assert mode == ALL_COMPLETED - return list(futures), [] - async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]: """ Blocks until all futures are completed. Returns a list of all futures. """ - completed, _ = await wait(*futures, mode=ALL_COMPLETED) - return completed + async for _ in as_completed(*futures): + pass + return list(futures) async def as_completed(*futures: RestateDurableFuture[Any]): """ @@ -79,12 +40,12 @@ async def as_completed(*futures: RestateDurableFuture[Any]): """ remaining = list(futures) while remaining: - completed, waiting = await wait_completed(remaining) + completed, waiting = await wait_completed(*remaining) for f in completed: yield f remaining = waiting -async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: +async def wait_completed(*args: RestateDurableFuture[Any]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]: """ Blocks until at least one of the futures is completed. @@ -95,6 +56,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List context: ServerInvocationContext | None = None completed = [] uncompleted = [] + futures = list(args) if not futures: return [], [] @@ -108,7 +70,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List if f.is_completed(): completed.append(f) else: - handles.append(f.source_notification_handle) + handles.append(f.handle) uncompleted.append(f) if completed: diff --git a/python/restate/context.py b/python/restate/context.py index c2dfec9..022a8f6 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -33,22 +33,10 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]): Represents a durable future. """ - @abc.abstractmethod - def is_completed(self) -> bool: - """ - Returns True if the future is completed, False otherwise. - """ - @abc.abstractmethod def __await__(self): pass - @abc.abstractmethod - def map_value(self, mapper: Callable[[T], O]) -> 'RestateDurableFuture[O]': - """ - Maps the value of the future using the given function. - """ - # pylint: disable=R0903 class RestateDurableCallFuture(RestateDurableFuture[T]): @@ -63,7 +51,6 @@ async def invocation_id(self) -> str: """ - @dataclass class Request: """ diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 8956250..7e577fb 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -35,98 +35,85 @@ I = TypeVar('I') O = TypeVar('O') +class LazyFuture: + """ + Creates a task lazily, and allows multiple awaiters to the same coroutine. + The async_def will be executed at most 1 times. (0 if __await__ or get() not called) + """ + __slots__ = ['async_def', 'task'] + + def __init__(self, async_def: Callable[[], typing.Coroutine[Any, Any, T]]) -> None: + assert async_def is not None + self.async_def = async_def + self.task: asyncio.Task | None = None + + def done(self): + """ + check if completed + """ + return self.task is not None and self.task.done() + + async def get(self) -> T: + """Get the value of the future.""" + if self.task is None: + self.task = asyncio.create_task(self.async_def()) + + return await self.task + + def __await__(self): + return self.get().__await__() class ServerDurableFuture(RestateDurableFuture[T]): """This class implements a durable future API""" - value: T | None = None - error: TerminalError | None = None - state: typing.Literal["pending", "fulfilled", "rejected"] = "pending" - def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None: + def __init__(self, context: "ServerInvocationContext", handle: int, async_def) -> None: super().__init__() self.context = context - self.source_notification_handle = handle - self.awaitable_factory = awaitable_factory - self.state = "pending" - + self.handle = handle + self.future = LazyFuture(async_def) def is_completed(self): - match self.state: - case "pending": - return self.context.vm.is_completed(self.source_notification_handle) - case "fulfilled": - return True - case "rejected": - return True - - def map_value(self, mapper: Callable[[T], O]) -> RestateDurableFuture[O]: - """Map the value of the future.""" - async def mapper_coro(): - return mapper(await self) - - return ServerDurableFuture(self.context, self.source_notification_handle, mapper_coro) - + """ + A future is completed, either it was physically completed and its value has been collected. + OR it might not yet physically completed (i.e. the async_def didn't finish yet) BUT our VM + already has a completion value for it. + """ + return self.future.done() or self.context.vm.is_completed(self.handle) def __await__(self): - - async def await_point(): - match self.state: - case "pending": - try: - self.value = await self.awaitable_factory() - self.state = "fulfilled" - return self.value - except TerminalError as t: - self.error = t - self.state = "rejected" - raise t - case "fulfilled": - return self.value - case "rejected": - assert self.error is not None - raise self.error - - - return await_point().__await__() + return self.future.__await__() class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]): """This class implements a durable future but for calls""" - _invocation_id: typing.Optional[str] = None def __init__(self, context: "ServerInvocationContext", result_handle: int, - result_factory, - invocation_id_handle: int, - invocation_id_factory) -> None: - super().__init__(context, result_handle, result_factory) - self.invocation_id_handle = invocation_id_handle - self.invocation_id_factory = invocation_id_factory - + result_async_def, + invocation_id_async_def) -> None: + super().__init__(context, result_handle, result_async_def) + self.invocation_id_future = LazyFuture(invocation_id_async_def) async def invocation_id(self) -> str: """Get the invocation id.""" - if self._invocation_id is None: - self._invocation_id = await self.invocation_id_factory() - return self._invocation_id + return await self.invocation_id_future.get() class ServerSendHandle(SendHandle): """This class implements the send API""" - _invocation_id: typing.Optional[str] - def __init__(self, context, handle: int) -> None: + def __init__(self, context: "ServerInvocationContext", handle: int) -> None: super().__init__() - self.handle = handle - self.context = context - self._invocation_id = None + + async def coro(): + if not context.vm.is_completed(handle): + await context.create_poll_or_cancel_coroutine([handle]) + return context.must_take_notification(handle) + + self.future = LazyFuture(coro) async def invocation_id(self) -> str: """Get the invocation id.""" - if self._invocation_id is not None: - return self._invocation_id - res = await self.context.create_poll_or_cancel_coroutine(self.handle) - self._invocation_id = res - return res + return await self.future async def async_value(n: Callable[[], T]) -> T: """convert a simple value to a coroutine.""" @@ -334,6 +321,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No continue if isinstance(do_progress_response, DoProgressExecuteRun): fn = self.run_coros_to_execute[do_progress_response.handle] + del self.run_coros_to_execute[do_progress_response.handle] assert fn is not None async def wrapper(f): @@ -346,11 +334,12 @@ async def wrapper(f): if isinstance(do_progress_response, DoWaitPendingRun): await self.sync_point.wait() - def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]: + def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]: """Create a durable future.""" async def transform(): - await self.create_poll_or_cancel_coroutine([handle]) + if not self.vm.is_completed(handle): + await self.create_poll_or_cancel_coroutine([handle]) res = self.must_take_notification(handle) if res is None or serde is None: return res @@ -359,30 +348,31 @@ async def transform(): return ServerDurableFuture(self, handle, transform) - - def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]: + def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]: """Create a durable future.""" async def transform(): - await self.create_poll_or_cancel_coroutine([handle]) + if not self.vm.is_completed(handle): + await self.create_poll_or_cancel_coroutine([handle]) res = self.must_take_notification(handle) if res is None or serde is None: return res return serde.deserialize(res) async def inv_id_factory(): - await self.create_poll_or_cancel_coroutine([invocation_id_handle]) + if not self.vm.is_completed(invocation_id_handle): + await self.create_poll_or_cancel_coroutine([invocation_id_handle]) return self.must_take_notification(invocation_id_handle) - return ServerCallDurableFuture(self, handle, transform, invocation_id_handle, inv_id_factory) + return ServerCallDurableFuture(self, handle, transform, inv_id_factory) def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]: handle = self.vm.sys_get_state(name) - return self.create_df(handle, serde) # type: ignore + return self.create_future(handle, serde) # type: ignore def state_keys(self) -> Awaitable[List[str]]: - return self.create_df(self.vm.sys_get_state_keys()) # type: ignore + return self.create_future(self.vm.sys_get_state_keys()) # type: ignore def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: """Set the value associated with the given name.""" @@ -446,13 +436,13 @@ def run(self, self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration) # Prepare response coroutine - return self.create_df(handle, serde) # type: ignore + return self.create_future(handle, serde) # type: ignore def sleep(self, delta: timedelta) -> RestateDurableFuture[None]: # convert timedelta to milliseconds millis = int(delta.total_seconds() * 1000) - return self.create_df(self.vm.sys_sleep(millis)) # type: ignore + return self.create_future(self.vm.sys_sleep(millis)) # type: ignore def do_call(self, tpe: Callable[[Any, I], Awaitable[O]], @@ -501,7 +491,7 @@ def do_raw_call(self, idempotency_key=idempotency_key, headers=headers) - return self.create_call_df(handle=handle.result_handle, + return self.create_call_future(handle=handle.result_handle, invocation_id_handle=handle.invocation_id_handle, serde=output_serde) @@ -582,7 +572,7 @@ def awakeable(self, serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]: assert serde is not None name, handle = self.vm.sys_awakeable() - return name, self.create_df(handle, serde) + return name, self.create_future(handle, serde) def resolve_awakeable(self, name: str, @@ -613,4 +603,4 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) - raise ValueError("invocation_id cannot be None") assert serde is not None handle = self.vm.attach_invocation(invocation_id) - return self.create_df(handle, serde) + return self.create_future(handle, serde)