Skip to content

Rename combinators to asyncio #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from .context import Context, ObjectContext, ObjectSharedContext
from .context import WorkflowContext, WorkflowSharedContext
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle
from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED
from .exceptions import TerminalError
from .asyncio import as_completed, gather, wait_completed

from .endpoint import app

Expand Down Expand Up @@ -51,9 +51,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
"TerminalError",
"app",
"test_harness",
"wait",
"gather",
"as_completed",
"ALL_COMPLETED",
"FIRST_COMPLETED",
"wait_completed",
]
52 changes: 7 additions & 45 deletions python/restate/combinators.py → python/restate/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,15 @@
from restate.context import RestateDurableFuture
from restate.server_context import ServerDurableFuture, ServerInvocationContext

FIRST_COMPLETED = 1
ALL_COMPLETED = 2

async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
"""
Blocks until at least one of the futures/all of the futures are completed.

Returns a tuple of two lists: the first list contains the futures that are completed,
the second list contains the futures that are not completed.

The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED.
Using FIRST_COMPLETED will return as soon as one of the futures is completed.
Using ALL_COMPLETED will return only when all futures are completed.

examples:

completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED)
for completed_future in completed:
# do something with the completed future
print(await completed_future) # prints the result of the future

or

completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED)
assert waiting == []


"""
assert mode in (FIRST_COMPLETED, ALL_COMPLETED)

remaining = list(futures)
while remaining:
completed, waiting = await wait_completed(remaining)
if mode == FIRST_COMPLETED:
return completed, waiting
remaining = waiting

assert mode == ALL_COMPLETED
return list(futures), []

async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
"""
Blocks until all futures are completed.

Returns a list of all futures.
"""
completed, _ = await wait(*futures, mode=ALL_COMPLETED)
return completed
async for _ in as_completed(*futures):
pass
return list(futures)

async def as_completed(*futures: RestateDurableFuture[Any]):
"""
Expand All @@ -79,12 +40,12 @@ async def as_completed(*futures: RestateDurableFuture[Any]):
"""
remaining = list(futures)
while remaining:
completed, waiting = await wait_completed(remaining)
completed, waiting = await wait_completed(*remaining)
for f in completed:
yield f
remaining = waiting

async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
async def wait_completed(*args: RestateDurableFuture[Any]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
"""
Blocks until at least one of the futures is completed.

Expand All @@ -95,6 +56,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
context: ServerInvocationContext | None = None
completed = []
uncompleted = []
futures = list(args)

if not futures:
return [], []
Expand All @@ -108,7 +70,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
if f.is_completed():
completed.append(f)
else:
handles.append(f.source_notification_handle)
handles.append(f.handle)
uncompleted.append(f)

if completed:
Expand Down
13 changes: 0 additions & 13 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,10 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
Represents a durable future.
"""

@abc.abstractmethod
def is_completed(self) -> bool:
"""
Returns True if the future is completed, False otherwise.
"""

@abc.abstractmethod
def __await__(self):
pass

@abc.abstractmethod
def map_value(self, mapper: Callable[[T], O]) -> 'RestateDurableFuture[O]':
"""
Maps the value of the future using the given function.
"""


# pylint: disable=R0903
class RestateDurableCallFuture(RestateDurableFuture[T]):
Expand All @@ -63,7 +51,6 @@ async def invocation_id(self) -> str:
"""



@dataclass
class Request:
"""
Expand Down
146 changes: 68 additions & 78 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,98 +35,85 @@
I = TypeVar('I')
O = TypeVar('O')

class LazyFuture:
"""
Creates a task lazily, and allows multiple awaiters to the same coroutine.
The async_def will be executed at most 1 times. (0 if __await__ or get() not called)
"""
__slots__ = ['async_def', 'task']

def __init__(self, async_def: Callable[[], typing.Coroutine[Any, Any, T]]) -> None:
assert async_def is not None
self.async_def = async_def
self.task: asyncio.Task | None = None

def done(self):
"""
check if completed
"""
return self.task is not None and self.task.done()

async def get(self) -> T:
"""Get the value of the future."""
if self.task is None:
self.task = asyncio.create_task(self.async_def())

return await self.task

def __await__(self):
return self.get().__await__()

class ServerDurableFuture(RestateDurableFuture[T]):
"""This class implements a durable future API"""
value: T | None = None
error: TerminalError | None = None
state: typing.Literal["pending", "fulfilled", "rejected"] = "pending"

def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None:
def __init__(self, context: "ServerInvocationContext", handle: int, async_def) -> None:
super().__init__()
self.context = context
self.source_notification_handle = handle
self.awaitable_factory = awaitable_factory
self.state = "pending"

self.handle = handle
self.future = LazyFuture(async_def)

def is_completed(self):
match self.state:
case "pending":
return self.context.vm.is_completed(self.source_notification_handle)
case "fulfilled":
return True
case "rejected":
return True

def map_value(self, mapper: Callable[[T], O]) -> RestateDurableFuture[O]:
"""Map the value of the future."""
async def mapper_coro():
return mapper(await self)

return ServerDurableFuture(self.context, self.source_notification_handle, mapper_coro)

"""
A future is completed, either it was physically completed and its value has been collected.
OR it might not yet physically completed (i.e. the async_def didn't finish yet) BUT our VM
already has a completion value for it.
"""
return self.future.done() or self.context.vm.is_completed(self.handle)

def __await__(self):

async def await_point():
match self.state:
case "pending":
try:
self.value = await self.awaitable_factory()
self.state = "fulfilled"
return self.value
except TerminalError as t:
self.error = t
self.state = "rejected"
raise t
case "fulfilled":
return self.value
case "rejected":
assert self.error is not None
raise self.error


return await_point().__await__()
return self.future.__await__()

class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
"""This class implements a durable future but for calls"""
_invocation_id: typing.Optional[str] = None

def __init__(self,
context: "ServerInvocationContext",
result_handle: int,
result_factory,
invocation_id_handle: int,
invocation_id_factory) -> None:
super().__init__(context, result_handle, result_factory)
self.invocation_id_handle = invocation_id_handle
self.invocation_id_factory = invocation_id_factory

result_async_def,
invocation_id_async_def) -> None:
super().__init__(context, result_handle, result_async_def)
self.invocation_id_future = LazyFuture(invocation_id_async_def)

async def invocation_id(self) -> str:
"""Get the invocation id."""
if self._invocation_id is None:
self._invocation_id = await self.invocation_id_factory()
return self._invocation_id
return await self.invocation_id_future.get()

class ServerSendHandle(SendHandle):
"""This class implements the send API"""
_invocation_id: typing.Optional[str]

def __init__(self, context, handle: int) -> None:
def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
super().__init__()
self.handle = handle
self.context = context
self._invocation_id = None

async def coro():
if not context.vm.is_completed(handle):
await context.create_poll_or_cancel_coroutine([handle])
return context.must_take_notification(handle)

self.future = LazyFuture(coro)

async def invocation_id(self) -> str:
"""Get the invocation id."""
if self._invocation_id is not None:
return self._invocation_id
res = await self.context.create_poll_or_cancel_coroutine(self.handle)
self._invocation_id = res
return res
return await self.future

async def async_value(n: Callable[[], T]) -> T:
"""convert a simple value to a coroutine."""
Expand Down Expand Up @@ -334,6 +321,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
continue
if isinstance(do_progress_response, DoProgressExecuteRun):
fn = self.run_coros_to_execute[do_progress_response.handle]
del self.run_coros_to_execute[do_progress_response.handle]
assert fn is not None

async def wrapper(f):
Expand All @@ -346,11 +334,12 @@ async def wrapper(f):
if isinstance(do_progress_response, DoWaitPendingRun):
await self.sync_point.wait()

def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
"""Create a durable future."""

async def transform():
await self.create_poll_or_cancel_coroutine([handle])
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
return res
Expand All @@ -359,30 +348,31 @@ async def transform():
return ServerDurableFuture(self, handle, transform)



def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
"""Create a durable future."""

async def transform():
await self.create_poll_or_cancel_coroutine([handle])
if not self.vm.is_completed(handle):
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
return res
return serde.deserialize(res)

async def inv_id_factory():
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
if not self.vm.is_completed(invocation_id_handle):
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
return self.must_take_notification(invocation_id_handle)

return ServerCallDurableFuture(self, handle, transform, invocation_id_handle, inv_id_factory)
return ServerCallDurableFuture(self, handle, transform, inv_id_factory)


def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
handle = self.vm.sys_get_state(name)
return self.create_df(handle, serde) # type: ignore
return self.create_future(handle, serde) # type: ignore

def state_keys(self) -> Awaitable[List[str]]:
return self.create_df(self.vm.sys_get_state_keys()) # type: ignore
return self.create_future(self.vm.sys_get_state_keys()) # type: ignore

def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
"""Set the value associated with the given name."""
Expand Down Expand Up @@ -446,13 +436,13 @@ def run(self,
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)

# Prepare response coroutine
return self.create_df(handle, serde) # type: ignore
return self.create_future(handle, serde) # type: ignore


def sleep(self, delta: timedelta) -> RestateDurableFuture[None]:
# convert timedelta to milliseconds
millis = int(delta.total_seconds() * 1000)
return self.create_df(self.vm.sys_sleep(millis)) # type: ignore
return self.create_future(self.vm.sys_sleep(millis)) # type: ignore

def do_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
Expand Down Expand Up @@ -501,7 +491,7 @@ def do_raw_call(self,
idempotency_key=idempotency_key,
headers=headers)

return self.create_call_df(handle=handle.result_handle,
return self.create_call_future(handle=handle.result_handle,
invocation_id_handle=handle.invocation_id_handle,
serde=output_serde)

Expand Down Expand Up @@ -582,7 +572,7 @@ def awakeable(self,
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
assert serde is not None
name, handle = self.vm.sys_awakeable()
return name, self.create_df(handle, serde)
return name, self.create_future(handle, serde)

def resolve_awakeable(self,
name: str,
Expand Down Expand Up @@ -613,4 +603,4 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -
raise ValueError("invocation_id cannot be None")
assert serde is not None
handle = self.vm.attach_invocation(invocation_id)
return self.create_df(handle, serde)
return self.create_future(handle, serde)