Skip to content

Commit 36b64c3

Browse files
committed
Unify cases
1 parent 1c5f4dc commit 36b64c3

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

python/restate/server_context.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,34 @@ class SyncPoint:
225225

226226
def __init__(self):
227227
self._cond = asyncio.Condition()
228+
self._running = set[asyncio.Task]()
228229

229-
async def wait(self):
230+
async def wait(self, snapshot: set[asyncio.Task]):
230231
"""Wait for the sync point."""
232+
233+
def wait_fn():
234+
"""
235+
Check if any task had finished.
236+
"""
237+
assert self._cond.locked()
238+
for task in snapshot:
239+
if task not in self._running:
240+
return True
241+
return False
242+
231243
async with self._cond:
232-
await self._cond.wait()
244+
await self._cond.wait_for(wait_fn)
233245

234-
async def arrive(self):
246+
async def enter(self):
235247
"""Arrive at the sync point."""
236248
async with self._cond:
249+
self._running.add(asyncio.current_task())
250+
await self._cond.notify_all()
251+
252+
async def leave(self):
253+
"""leave the sync point."""
254+
async with self._cond:
255+
self._running.remove(asyncio.current_task())
237256
self._cond.notify_all()
238257

239258
class Tasks:
@@ -398,33 +417,31 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398417
return
399418
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
400419
raise TerminalError("cancelled", 409)
401-
if isinstance(do_progress_response, DoProgressReadFromInput):
402-
await self.receive_and_notify_input()
403-
continue
404420
if isinstance(do_progress_response, DoProgressExecuteRun):
405421
fn = self.run_coros_to_execute[do_progress_response.handle]
406422
del self.run_coros_to_execute[do_progress_response.handle]
407423
assert fn is not None
408424

409425
async def wrapper(f):
426+
await self.sync_point.enter()
410427
await f()
411428
await self.take_and_send_output()
412-
await self.sync_point.arrive()
429+
await self.sync_point.leave()
413430

414431
task = asyncio.create_task(wrapper(fn))
415432
self.tasks.add(task)
433+
416434
continue
417-
if isinstance(do_progress_response, DoWaitPendingRun):
418-
sync_task = asyncio.create_task(self.sync_point.wait())
435+
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
436+
snapshot = self.sync_point._running.copy()
437+
sync_task = asyncio.create_task(self.sync_point.wait(snapshot))
419438
read_task = asyncio.create_task(self.receive_and_notify_input())
420439
self.tasks.add(sync_task)
421440
self.tasks.add(read_task)
422441
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
423442
if read_task in done:
424443
_ = read_task.result() # rethrow any exception
425-
if sync_task in done:
426-
continue
427-
444+
428445

429446
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
430447
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +537,8 @@ async def create_run_coroutine(self,
520537
except TerminalError as t:
521538
failure = Failure(code=t.status_code, message=t.message)
522539
self.vm.propose_run_completion_failure(handle, failure)
540+
except asyncio.CancelledError as e:
541+
raise e from None
523542
# pylint: disable=W0718
524543
except Exception as e:
525544
if max_attempts is None and max_retry_duration is None:

0 commit comments

Comments
 (0)