Skip to content

Commit 55dd26b

Browse files
authored
Unify read from input with side effect completion (#86)
1 parent 1c5f4dc commit 55dd26b

File tree

3 files changed

+30
-26
lines changed

3 files changed

+30
-26
lines changed

.github/workflows/integration.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,4 @@ jobs:
109109
serviceContainerImage: "restatedev/test-services-python"
110110
exclusionsFile: "test-services/exclusions.yaml"
111111
testArtifactOutput: "sdk-python-integration-test-report"
112+
serviceContainerEnvFile: "test-services/.env"

python/restate/server_context.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ async def cancel_invocation(self) -> None:
143143
144144
await ctx.cancel_invocation(await f.invocation_id())
145145
"""
146+
inv = await self.invocation_id()
147+
await self.context.cancel_invocation(inv)
146148

147149
class ServerSendHandle(SendHandle):
148150
"""This class implements the send API"""
@@ -223,45 +225,47 @@ class SyncPoint:
223225
This class implements a synchronization point.
224226
"""
225227

226-
def __init__(self):
227-
self._cond = asyncio.Condition()
228+
def __init__(self) -> None:
229+
self.cond: asyncio.Event | None = None
228230

229-
async def wait(self):
231+
def awaiter(self):
230232
"""Wait for the sync point."""
231-
async with self._cond:
232-
await self._cond.wait()
233+
if self.cond is None:
234+
self.cond = asyncio.Event()
235+
return self.cond.wait()
233236

234237
async def arrive(self):
235-
"""Arrive at the sync point."""
236-
async with self._cond:
237-
self._cond.notify_all()
238+
"""arrive at the sync point."""
239+
if self.cond is not None:
240+
self.cond.set()
238241

239242
class Tasks:
240243
"""
241244
This class implements a list of tasks.
242245
"""
243246

244247
def __init__(self) -> None:
245-
self.tasks: List[asyncio.Future] = []
248+
self.tasks: set[asyncio.Future] = set()
246249

247250
def add(self, task: asyncio.Future):
248251
"""Add a task to the list."""
249-
self.tasks.append(task)
252+
self.tasks.add(task)
250253

251254
def safe_remove(_):
252255
"""Remove the task from the list."""
253256
try:
254257
self.tasks.remove(task)
255-
except ValueError:
258+
except KeyError:
256259
pass
257260

258261
task.add_done_callback(safe_remove)
259262

260263
def cancel(self):
261264
"""Cancel all tasks in the list."""
262-
for task in self.tasks:
263-
task.cancel()
265+
to_cancel = list(self.tasks)
264266
self.tasks.clear()
267+
for task in to_cancel:
268+
task.cancel()
265269

266270
# pylint: disable=R0902
267271
class ServerInvocationContext(ObjectContext):
@@ -358,13 +362,13 @@ def on_attempt_finished(self):
358362
async def receive_and_notify_input(self):
359363
"""Receive input from the state machine."""
360364
chunk = await self.receive()
361-
if chunk.get('type') == 'http.request':
365+
if chunk.get('type') == 'http.disconnect':
366+
raise DisconnectedException()
367+
if chunk.get('body', None) is not None:
362368
assert isinstance(chunk['body'], bytes)
363369
self.vm.notify_input(chunk['body'])
364370
if not chunk.get('more_body', False):
365371
self.vm.notify_input_closed()
366-
if chunk.get('type') == 'http.disconnect':
367-
raise DisconnectedException()
368372

369373
async def take_and_send_output(self):
370374
"""Take output from state machine and send it"""
@@ -398,9 +402,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398402
return
399403
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
400404
raise TerminalError("cancelled", 409)
401-
if isinstance(do_progress_response, DoProgressReadFromInput):
402-
await self.receive_and_notify_input()
403-
continue
404405
if isinstance(do_progress_response, DoProgressExecuteRun):
405406
fn = self.run_coros_to_execute[do_progress_response.handle]
406407
del self.run_coros_to_execute[do_progress_response.handle]
@@ -414,17 +415,16 @@ async def wrapper(f):
414415
task = asyncio.create_task(wrapper(fn))
415416
self.tasks.add(task)
416417
continue
417-
if isinstance(do_progress_response, DoWaitPendingRun):
418-
sync_task = asyncio.create_task(self.sync_point.wait())
419-
read_task = asyncio.create_task(self.receive_and_notify_input())
418+
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
419+
sync_task = asyncio.create_task(self.sync_point.awaiter())
420420
self.tasks.add(sync_task)
421+
422+
read_task = asyncio.create_task(self.receive_and_notify_input())
421423
self.tasks.add(read_task)
424+
422425
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
423426
if read_task in done:
424-
_ = read_task.result() # rethrow any exception
425-
if sync_task in done:
426-
continue
427-
427+
_ = read_task.result() # propagate exception
428428

429429
def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
430430
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +520,8 @@ async def create_run_coroutine(self,
520520
except TerminalError as t:
521521
failure = Failure(code=t.status_code, message=t.message)
522522
self.vm.propose_run_completion_failure(handle, failure)
523+
except asyncio.CancelledError as e:
524+
raise e from None
523525
# pylint: disable=W0718
524526
except Exception as e:
525527
if max_attempts is None and max_retry_duration is None:

test-services/.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
RESTATE_CORE_LOG=trace

0 commit comments

Comments
 (0)