Skip to content

Commit f056467

Browse files
committed
Add restate future combinators
1 parent 5212013 commit f056467

File tree

4 files changed

+155
-14
lines changed

4 files changed

+155
-14
lines changed

python/restate/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from .context import Context, ObjectContext, ObjectSharedContext
2121
from .context import WorkflowContext, WorkflowSharedContext
2222
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle
23+
from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED
24+
from .exceptions import TerminalError
2325

2426
from .endpoint import app
2527

@@ -46,6 +48,12 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
4648
"RestateDurableFuture",
4749
"RestateDurableCallFuture",
4850
"SendHandle",
51+
"TerminalError",
4952
"app",
5053
"test_harness",
54+
"wait",
55+
"gather",
56+
"as_completed",
57+
"ALL_COMPLETED",
58+
"FIRST_COMPLETED",
5159
]

python/restate/combinators.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#
2+
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
# pylint: disable=R0913,C0301,R0917
12+
# pylint: disable=line-too-long
13+
"""combines multiple futures into a single future"""
14+
15+
from typing import Any, AsyncGenerator, List, Tuple
16+
from restate.exceptions import TerminalError
17+
from restate.context import RestateDurableFuture
18+
from restate.server_context import ServerDurableFuture, ServerInvocationContext
19+
20+
FIRST_COMPLETED = 1
21+
ALL_COMPLETED = 2
22+
23+
async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
24+
"""
25+
Blocks until at least one of the futures/all of the futures are completed.
26+
27+
Returns a tuple of two lists: the first list contains the futures that are completed,
28+
the second list contains the futures that are not completed.
29+
30+
The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED.
31+
Using FIRST_COMPLETED will return as soon as one of the futures is completed.
32+
Using ALL_COMPLETED will return only when all futures are completed.
33+
34+
examples:
35+
36+
completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED)
37+
for completed_future in completed:
38+
# do something with the completed future
39+
print(await completed_future) # prints the result of the future
40+
41+
or
42+
43+
completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED)
44+
assert waiting == []
45+
46+
47+
"""
48+
assert mode in (FIRST_COMPLETED, ALL_COMPLETED)
49+
50+
remaining = list(futures)
51+
while remaining:
52+
completed, waiting = await wait_completed(remaining)
53+
if mode == FIRST_COMPLETED:
54+
return completed, waiting
55+
remaining = waiting
56+
57+
assert mode == ALL_COMPLETED
58+
return list(futures), []
59+
60+
async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
61+
"""
62+
Blocks until all futures are completed.
63+
64+
Returns a list of all futures.
65+
"""
66+
completed, _ = await wait(*futures, mode=ALL_COMPLETED)
67+
return completed
68+
69+
async def as_completed(*futures: RestateDurableFuture[Any]) -> AsyncGenerator[RestateDurableFuture[Any]]:
70+
"""
71+
Returns an iterator that yields the futures as they are completed.
72+
73+
example:
74+
75+
async for future in as_completed(f1, f2, f3):
76+
# do something with the completed future
77+
print(await future) # prints the result of the future
78+
79+
"""
80+
remaining = list(futures)
81+
while remaining:
82+
completed, waiting = await wait_completed(remaining)
83+
for f in completed:
84+
yield f
85+
remaining = waiting
86+
87+
async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
88+
"""
89+
Blocks until at least one of the futures is completed.
90+
91+
Returns a tuple of two lists: the first list contains the futures that are completed,
92+
the second list contains the futures that are not completed.
93+
"""
94+
if not futures:
95+
return [], []
96+
handles: List[int] = []
97+
context: ServerInvocationContext | None = None
98+
for f in futures:
99+
if not isinstance(f, ServerDurableFuture):
100+
raise TerminalError("All futures must SDK created futures.")
101+
if context is None:
102+
context = f.context
103+
elif context is not f.context:
104+
raise TerminalError("All futures must be created by the same SDK context.")
105+
if f.is_completed():
106+
return [f], []
107+
handles.append(f.source_notification_handle)
108+
109+
assert context is not None
110+
await context.create_poll_or_cancel_coroutine(handles)
111+
completed = []
112+
uncompleted = []
113+
for index, handle in enumerate(handles):
114+
future = futures[index]
115+
if context.vm.is_completed(handle):
116+
completed.append(future)
117+
else:
118+
uncompleted.append(future)
119+
return completed, uncompleted
120+
121+

python/restate/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
3333
Represents a durable future.
3434
"""
3535

36+
@abc.abstractmethod
37+
def is_completed(self) -> bool:
38+
"""
39+
Returns True if the future is completed, False otherwise.
40+
"""
41+
3642
@abc.abstractmethod
3743
def __await__(self):
3844
pass

python/restate/server_context.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,27 +36,37 @@
3636
O = TypeVar('O')
3737

3838

39-
4039
class ServerDurableFuture(RestateDurableFuture[T]):
4140
"""This class implements a durable future API"""
4241
value: T | None = None
4342
error: TerminalError | None = None
4443
state: typing.Literal["pending", "fulfilled", "rejected"] = "pending"
4544

46-
def __init__(self, handle: int, factory, ctx: "ServerInvocationContext") -> None:
45+
def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None:
4746
super().__init__()
48-
self.factory = factory
49-
self.handle = handle
50-
self.context = ctx
47+
self.context = context
48+
self.source_notification_handle = handle
49+
self.awaitable_factory = awaitable_factory
5150
self.state = "pending"
5251

52+
53+
def is_completed(self):
54+
match self.state:
55+
case "pending":
56+
return self.context.vm.is_completed(self.source_notification_handle)
57+
case "fulfilled":
58+
return True
59+
case "rejected":
60+
return True
61+
62+
5363
def __await__(self):
5464

5565
async def await_point():
5666
match self.state:
5767
case "pending":
5868
try:
59-
self.value = await self.factory()
69+
self.value = await self.awaitable_factory()
6070
self.state = "fulfilled"
6171
return self.value
6272
except TerminalError as t:
@@ -71,31 +81,28 @@ async def await_point():
7181

7282

7383
return await_point().__await__()
74-
#task = asyncio.create_task(self.factory())
75-
#return task.__await__()
76-
7784

7885
class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
7986
"""This class implements a durable future but for calls"""
8087
_invocation_id: typing.Optional[str] = None
8188

8289
def __init__(self,
83-
ctx: "ServerInvocationContext",
90+
context: "ServerInvocationContext",
8491
result_handle: int,
8592
result_factory,
8693
invocation_id_handle: int,
8794
invocation_id_factory) -> None:
88-
super().__init__(result_handle, result_factory, ctx)
95+
super().__init__(context, result_handle, result_factory)
8996
self.invocation_id_handle = invocation_id_handle
9097
self.invocation_id_factory = invocation_id_factory
9198

99+
92100
async def invocation_id(self) -> str:
93101
"""Get the invocation id."""
94102
if self._invocation_id is None:
95103
self._invocation_id = await self.invocation_id_factory()
96104
return self._invocation_id
97105

98-
99106
class ServerSendHandle(SendHandle):
100107
"""This class implements the send API"""
101108
_invocation_id: typing.Optional[str]
@@ -118,7 +125,6 @@ async def async_value(n: Callable[[], T]) -> T:
118125
"""convert a simple value to a coroutine."""
119126
return n()
120127

121-
122128
class ServerDurablePromise(DurablePromise):
123129
"""This class implements a durable promise API"""
124130

@@ -314,7 +320,7 @@ async def transform():
314320
return res
315321
return serde.deserialize(res)
316322

317-
return ServerDurableFuture(handle, transform, self)
323+
return ServerDurableFuture(self, handle, transform)
318324

319325

320326

0 commit comments

Comments
 (0)