Skip to content

Commit 37959cd

Browse files
authored
Rename combinators to asyncio (#59)
* Rename combinators to asyncio * Make the future handlers more robust
1 parent 032e826 commit 37959cd

File tree

4 files changed

+77
-140
lines changed

4 files changed

+77
-140
lines changed

python/restate/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from .context import Context, ObjectContext, ObjectSharedContext
2121
from .context import WorkflowContext, WorkflowSharedContext
2222
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle
23-
from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED
2423
from .exceptions import TerminalError
24+
from .asyncio import as_completed, gather, wait_completed
2525

2626
from .endpoint import app
2727

@@ -51,9 +51,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
5151
"TerminalError",
5252
"app",
5353
"test_harness",
54-
"wait",
5554
"gather",
5655
"as_completed",
57-
"ALL_COMPLETED",
58-
"FIRST_COMPLETED",
56+
"wait_completed",
5957
]

python/restate/combinators.py renamed to python/restate/asyncio.py

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,15 @@
1717
from restate.context import RestateDurableFuture
1818
from restate.server_context import ServerDurableFuture, ServerInvocationContext
1919

20-
FIRST_COMPLETED = 1
21-
ALL_COMPLETED = 2
22-
23-
async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
24-
"""
25-
Blocks until at least one of the futures/all of the futures are completed.
26-
27-
Returns a tuple of two lists: the first list contains the futures that are completed,
28-
the second list contains the futures that are not completed.
29-
30-
The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED.
31-
Using FIRST_COMPLETED will return as soon as one of the futures is completed.
32-
Using ALL_COMPLETED will return only when all futures are completed.
33-
34-
examples:
35-
36-
completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED)
37-
for completed_future in completed:
38-
# do something with the completed future
39-
print(await completed_future) # prints the result of the future
40-
41-
or
42-
43-
completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED)
44-
assert waiting == []
45-
46-
47-
"""
48-
assert mode in (FIRST_COMPLETED, ALL_COMPLETED)
49-
50-
remaining = list(futures)
51-
while remaining:
52-
completed, waiting = await wait_completed(remaining)
53-
if mode == FIRST_COMPLETED:
54-
return completed, waiting
55-
remaining = waiting
56-
57-
assert mode == ALL_COMPLETED
58-
return list(futures), []
59-
6020
async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
6121
"""
6222
Blocks until all futures are completed.
6323
6424
Returns a list of all futures.
6525
"""
66-
completed, _ = await wait(*futures, mode=ALL_COMPLETED)
67-
return completed
26+
async for _ in as_completed(*futures):
27+
pass
28+
return list(futures)
6829

6930
async def as_completed(*futures: RestateDurableFuture[Any]):
7031
"""
@@ -79,12 +40,12 @@ async def as_completed(*futures: RestateDurableFuture[Any]):
7940
"""
8041
remaining = list(futures)
8142
while remaining:
82-
completed, waiting = await wait_completed(remaining)
43+
completed, waiting = await wait_completed(*remaining)
8344
for f in completed:
8445
yield f
8546
remaining = waiting
8647

87-
async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
48+
async def wait_completed(*args: RestateDurableFuture[Any]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
8849
"""
8950
Blocks until at least one of the futures is completed.
9051
@@ -95,6 +56,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
9556
context: ServerInvocationContext | None = None
9657
completed = []
9758
uncompleted = []
59+
futures = list(args)
9860

9961
if not futures:
10062
return [], []
@@ -108,7 +70,7 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
10870
if f.is_completed():
10971
completed.append(f)
11072
else:
111-
handles.append(f.source_notification_handle)
73+
handles.append(f.handle)
11274
uncompleted.append(f)
11375

11476
if completed:

python/restate/context.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,10 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
3333
Represents a durable future.
3434
"""
3535

36-
@abc.abstractmethod
37-
def is_completed(self) -> bool:
38-
"""
39-
Returns True if the future is completed, False otherwise.
40-
"""
41-
4236
@abc.abstractmethod
4337
def __await__(self):
4438
pass
4539

46-
@abc.abstractmethod
47-
def map_value(self, mapper: Callable[[T], O]) -> 'RestateDurableFuture[O]':
48-
"""
49-
Maps the value of the future using the given function.
50-
"""
51-
5240

5341
# pylint: disable=R0903
5442
class RestateDurableCallFuture(RestateDurableFuture[T]):
@@ -63,7 +51,6 @@ async def invocation_id(self) -> str:
6351
"""
6452

6553

66-
6754
@dataclass
6855
class Request:
6956
"""

python/restate/server_context.py

Lines changed: 68 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -35,98 +35,85 @@
3535
I = TypeVar('I')
3636
O = TypeVar('O')
3737

38+
class LazyFuture:
39+
"""
40+
Creates a task lazily, and allows multiple awaiters to the same coroutine.
41+
The async_def will be executed at most 1 times. (0 if __await__ or get() not called)
42+
"""
43+
__slots__ = ['async_def', 'task']
44+
45+
def __init__(self, async_def: Callable[[], typing.Coroutine[Any, Any, T]]) -> None:
46+
assert async_def is not None
47+
self.async_def = async_def
48+
self.task: asyncio.Task | None = None
49+
50+
def done(self):
51+
"""
52+
check if completed
53+
"""
54+
return self.task is not None and self.task.done()
55+
56+
async def get(self) -> T:
57+
"""Get the value of the future."""
58+
if self.task is None:
59+
self.task = asyncio.create_task(self.async_def())
60+
61+
return await self.task
62+
63+
def __await__(self):
64+
return self.get().__await__()
3865

3966
class ServerDurableFuture(RestateDurableFuture[T]):
4067
"""This class implements a durable future API"""
41-
value: T | None = None
42-
error: TerminalError | None = None
43-
state: typing.Literal["pending", "fulfilled", "rejected"] = "pending"
4468

45-
def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None:
69+
def __init__(self, context: "ServerInvocationContext", handle: int, async_def) -> None:
4670
super().__init__()
4771
self.context = context
48-
self.source_notification_handle = handle
49-
self.awaitable_factory = awaitable_factory
50-
self.state = "pending"
51-
72+
self.handle = handle
73+
self.future = LazyFuture(async_def)
5274

5375
def is_completed(self):
54-
match self.state:
55-
case "pending":
56-
return self.context.vm.is_completed(self.source_notification_handle)
57-
case "fulfilled":
58-
return True
59-
case "rejected":
60-
return True
61-
62-
def map_value(self, mapper: Callable[[T], O]) -> RestateDurableFuture[O]:
63-
"""Map the value of the future."""
64-
async def mapper_coro():
65-
return mapper(await self)
66-
67-
return ServerDurableFuture(self.context, self.source_notification_handle, mapper_coro)
68-
76+
"""
77+
A future is completed, either it was physically completed and its value has been collected.
78+
OR it might not yet physically completed (i.e. the async_def didn't finish yet) BUT our VM
79+
already has a completion value for it.
80+
"""
81+
return self.future.done() or self.context.vm.is_completed(self.handle)
6982

7083
def __await__(self):
71-
72-
async def await_point():
73-
match self.state:
74-
case "pending":
75-
try:
76-
self.value = await self.awaitable_factory()
77-
self.state = "fulfilled"
78-
return self.value
79-
except TerminalError as t:
80-
self.error = t
81-
self.state = "rejected"
82-
raise t
83-
case "fulfilled":
84-
return self.value
85-
case "rejected":
86-
assert self.error is not None
87-
raise self.error
88-
89-
90-
return await_point().__await__()
84+
return self.future.__await__()
9185

9286
class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
9387
"""This class implements a durable future but for calls"""
94-
_invocation_id: typing.Optional[str] = None
9588

9689
def __init__(self,
9790
context: "ServerInvocationContext",
9891
result_handle: int,
99-
result_factory,
100-
invocation_id_handle: int,
101-
invocation_id_factory) -> None:
102-
super().__init__(context, result_handle, result_factory)
103-
self.invocation_id_handle = invocation_id_handle
104-
self.invocation_id_factory = invocation_id_factory
105-
92+
result_async_def,
93+
invocation_id_async_def) -> None:
94+
super().__init__(context, result_handle, result_async_def)
95+
self.invocation_id_future = LazyFuture(invocation_id_async_def)
10696

10797
async def invocation_id(self) -> str:
10898
"""Get the invocation id."""
109-
if self._invocation_id is None:
110-
self._invocation_id = await self.invocation_id_factory()
111-
return self._invocation_id
99+
return await self.invocation_id_future.get()
112100

113101
class ServerSendHandle(SendHandle):
114102
"""This class implements the send API"""
115-
_invocation_id: typing.Optional[str]
116103

117-
def __init__(self, context, handle: int) -> None:
104+
def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
118105
super().__init__()
119-
self.handle = handle
120-
self.context = context
121-
self._invocation_id = None
106+
107+
async def coro():
108+
if not context.vm.is_completed(handle):
109+
await context.create_poll_or_cancel_coroutine([handle])
110+
return context.must_take_notification(handle)
111+
112+
self.future = LazyFuture(coro)
122113

123114
async def invocation_id(self) -> str:
124115
"""Get the invocation id."""
125-
if self._invocation_id is not None:
126-
return self._invocation_id
127-
res = await self.context.create_poll_or_cancel_coroutine(self.handle)
128-
self._invocation_id = res
129-
return res
116+
return await self.future
130117

131118
async def async_value(n: Callable[[], T]) -> T:
132119
"""convert a simple value to a coroutine."""
@@ -334,6 +321,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
334321
continue
335322
if isinstance(do_progress_response, DoProgressExecuteRun):
336323
fn = self.run_coros_to_execute[do_progress_response.handle]
324+
del self.run_coros_to_execute[do_progress_response.handle]
337325
assert fn is not None
338326

339327
async def wrapper(f):
@@ -346,11 +334,12 @@ async def wrapper(f):
346334
if isinstance(do_progress_response, DoWaitPendingRun):
347335
await self.sync_point.wait()
348336

349-
def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
337+
def create_future(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
350338
"""Create a durable future."""
351339

352340
async def transform():
353-
await self.create_poll_or_cancel_coroutine([handle])
341+
if not self.vm.is_completed(handle):
342+
await self.create_poll_or_cancel_coroutine([handle])
354343
res = self.must_take_notification(handle)
355344
if res is None or serde is None:
356345
return res
@@ -359,30 +348,31 @@ async def transform():
359348
return ServerDurableFuture(self, handle, transform)
360349

361350

362-
363-
def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
351+
def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
364352
"""Create a durable future."""
365353

366354
async def transform():
367-
await self.create_poll_or_cancel_coroutine([handle])
355+
if not self.vm.is_completed(handle):
356+
await self.create_poll_or_cancel_coroutine([handle])
368357
res = self.must_take_notification(handle)
369358
if res is None or serde is None:
370359
return res
371360
return serde.deserialize(res)
372361

373362
async def inv_id_factory():
374-
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
363+
if not self.vm.is_completed(invocation_id_handle):
364+
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
375365
return self.must_take_notification(invocation_id_handle)
376366

377-
return ServerCallDurableFuture(self, handle, transform, invocation_id_handle, inv_id_factory)
367+
return ServerCallDurableFuture(self, handle, transform, inv_id_factory)
378368

379369

380370
def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
381371
handle = self.vm.sys_get_state(name)
382-
return self.create_df(handle, serde) # type: ignore
372+
return self.create_future(handle, serde) # type: ignore
383373

384374
def state_keys(self) -> Awaitable[List[str]]:
385-
return self.create_df(self.vm.sys_get_state_keys()) # type: ignore
375+
return self.create_future(self.vm.sys_get_state_keys()) # type: ignore
386376

387377
def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
388378
"""Set the value associated with the given name."""
@@ -446,13 +436,13 @@ def run(self,
446436
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)
447437

448438
# Prepare response coroutine
449-
return self.create_df(handle, serde) # type: ignore
439+
return self.create_future(handle, serde) # type: ignore
450440

451441

452442
def sleep(self, delta: timedelta) -> RestateDurableFuture[None]:
453443
# convert timedelta to milliseconds
454444
millis = int(delta.total_seconds() * 1000)
455-
return self.create_df(self.vm.sys_sleep(millis)) # type: ignore
445+
return self.create_future(self.vm.sys_sleep(millis)) # type: ignore
456446

457447
def do_call(self,
458448
tpe: Callable[[Any, I], Awaitable[O]],
@@ -501,7 +491,7 @@ def do_raw_call(self,
501491
idempotency_key=idempotency_key,
502492
headers=headers)
503493

504-
return self.create_call_df(handle=handle.result_handle,
494+
return self.create_call_future(handle=handle.result_handle,
505495
invocation_id_handle=handle.invocation_id_handle,
506496
serde=output_serde)
507497

@@ -582,7 +572,7 @@ def awakeable(self,
582572
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
583573
assert serde is not None
584574
name, handle = self.vm.sys_awakeable()
585-
return name, self.create_df(handle, serde)
575+
return name, self.create_future(handle, serde)
586576

587577
def resolve_awakeable(self,
588578
name: str,
@@ -613,4 +603,4 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -
613603
raise ValueError("invocation_id cannot be None")
614604
assert serde is not None
615605
handle = self.vm.attach_invocation(invocation_id)
616-
return self.create_df(handle, serde)
606+
return self.create_future(handle, serde)

0 commit comments

Comments
 (0)