Skip to content

Commit 2ba1360

Browse files
committed
Unify read from input with side effect completion
1 parent 1c5f4dc commit 2ba1360

File tree

3 files changed

+28
-26
lines changed

3 files changed

+28
-26
lines changed

.github/workflows/integration.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,4 @@ jobs:
109109
serviceContainerImage: "restatedev/test-services-python"
110110
exclusionsFile: "test-services/exclusions.yaml"
111111
testArtifactOutput: "sdk-python-integration-test-report"
112+
serviceContainerEnvFile: "test-services/.env"

python/restate/server_context.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -223,45 +223,47 @@ class SyncPoint:
223223
This class implements a synchronization point.
224224
"""
225225

226-
def __init__(self):
227-
self._cond = asyncio.Condition()
226+
def __init__(self) -> None:
227+
self.cond: asyncio.Event | None = None
228228

229-
async def wait(self):
229+
def awaiter(self):
230230
"""Wait for the sync point."""
231-
async with self._cond:
232-
await self._cond.wait()
231+
if self.cond is None:
232+
self.cond = asyncio.Event()
233+
return self.cond.wait()
233234

234235
async def arrive(self):
235-
"""Arrive at the sync point."""
236-
async with self._cond:
237-
self._cond.notify_all()
236+
"""arrive at the sync point."""
237+
if self.cond is not None:
238+
self.cond.set()
238239

239240
class Tasks:
240241
"""
241242
This class implements a list of tasks.
242243
"""
243244

244245
def __init__(self) -> None:
245-
self.tasks: List[asyncio.Future] = []
246+
self.tasks: set[asyncio.Future] = set()
246247

247248
def add(self, task: asyncio.Future):
248249
"""Add a task to the list."""
249-
self.tasks.append(task)
250+
self.tasks.add(task)
250251

251252
def safe_remove(_):
252253
"""Remove the task from the list."""
253254
try:
254255
self.tasks.remove(task)
255-
except ValueError:
256+
except KeyError:
256257
pass
257258

258259
task.add_done_callback(safe_remove)
259260

260261
def cancel(self):
261262
"""Cancel all tasks in the list."""
262-
for task in self.tasks:
263-
task.cancel()
263+
to_cancel = list(self.tasks)
264264
self.tasks.clear()
265+
for task in to_cancel:
266+
task.cancel()
265267

266268
# pylint: disable=R0902
267269
class ServerInvocationContext(ObjectContext):
@@ -358,13 +360,13 @@ def on_attempt_finished(self):
358360
async def receive_and_notify_input(self):
359361
"""Receive input from the state machine."""
360362
chunk = await self.receive()
361-
if chunk.get('type') == 'http.request':
363+
if chunk.get('type') == 'http.disconnect':
364+
raise DisconnectedException()
365+
if chunk.get('body', None) is not None:
362366
assert isinstance(chunk['body'], bytes)
363367
self.vm.notify_input(chunk['body'])
364368
if not chunk.get('more_body', False):
365369
self.vm.notify_input_closed()
366-
if chunk.get('type') == 'http.disconnect':
367-
raise DisconnectedException()
368370

369371
async def take_and_send_output(self):
370372
"""Take output from state machine and send it"""
@@ -398,9 +400,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398400
return
399401
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
400402
raise TerminalError("cancelled", 409)
401-
if isinstance(do_progress_response, DoProgressReadFromInput):
402-
await self.receive_and_notify_input()
403-
continue
404403
if isinstance(do_progress_response, DoProgressExecuteRun):
405404
fn = self.run_coros_to_execute[do_progress_response.handle]
406405
del self.run_coros_to_execute[do_progress_response.handle]
@@ -414,17 +413,16 @@ async def wrapper(f):
414413
task = asyncio.create_task(wrapper(fn))
415414
self.tasks.add(task)
416415
continue
417-
if isinstance(do_progress_response, DoWaitPendingRun):
418-
sync_task = asyncio.create_task(self.sync_point.wait())
419-
read_task = asyncio.create_task(self.receive_and_notify_input())
416+
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
417+
sync_task = asyncio.create_task(self.sync_point.awaiter())
420418
self.tasks.add(sync_task)
419+
420+
read_task = asyncio.create_task(self.receive_and_notify_input())
421421
self.tasks.add(read_task)
422+
422423
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
423424
if read_task in done:
424-
_ = read_task.result() # rethrow any exception
425-
if sync_task in done:
426-
continue
427-
425+
_ = read_task.result() # propagate exception
428426

429427
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
430428
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +518,8 @@ async def create_run_coroutine(self,
520518
except TerminalError as t:
521519
failure = Failure(code=t.status_code, message=t.message)
522520
self.vm.propose_run_completion_failure(handle, failure)
521+
except asyncio.CancelledError as e:
522+
raise e from None
523523
# pylint: disable=W0718
524524
except Exception as e:
525525
if max_attempts is None and max_retry_duration is None:

test-services/.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
RESTATE_CORE_LOG=trace

0 commit comments

Comments
 (0)