@@ -143,6 +143,8 @@ async def cancel_invocation(self) -> None:
143
143
144
144
await ctx.cancel_invocation(await f.invocation_id())
145
145
"""
146
+ inv = await self .invocation_id ()
147
+ await self .context .cancel_invocation (inv )
146
148
147
149
class ServerSendHandle (SendHandle ):
148
150
"""This class implements the send API"""
@@ -223,45 +225,47 @@ class SyncPoint:
223
225
This class implements a synchronization point.
224
226
"""
225
227
226
- def __init__ (self ):
227
- self ._cond = asyncio .Condition ()
228
+ def __init__ (self ) -> None :
229
+ self .cond : asyncio .Event | None = None
228
230
229
- async def wait (self ):
231
+ def awaiter (self ):
230
232
"""Wait for the sync point."""
231
- async with self ._cond :
232
- await self ._cond .wait ()
233
+ if self .cond is None :
234
+ self .cond = asyncio .Event ()
235
+ return self .cond .wait ()
233
236
234
237
async def arrive (self ):
235
- """Arrive at the sync point."""
236
- async with self ._cond :
237
- self ._cond . notify_all ()
238
+ """arrive at the sync point."""
239
+ if self .cond is not None :
240
+ self .cond . set ()
238
241
239
242
class Tasks :
240
243
"""
241
244
This class implements a list of tasks.
242
245
"""
243
246
244
247
def __init__ (self ) -> None :
245
- self .tasks : List [asyncio .Future ] = []
248
+ self .tasks : set [asyncio .Future ] = set ()
246
249
247
250
def add (self , task : asyncio .Future ):
248
251
"""Add a task to the list."""
249
- self .tasks .append (task )
252
+ self .tasks .add (task )
250
253
251
254
def safe_remove (_ ):
252
255
"""Remove the task from the list."""
253
256
try :
254
257
self .tasks .remove (task )
255
- except ValueError :
258
+ except KeyError :
256
259
pass
257
260
258
261
task .add_done_callback (safe_remove )
259
262
260
263
def cancel (self ):
261
264
"""Cancel all tasks in the list."""
262
- for task in self .tasks :
263
- task .cancel ()
265
+ to_cancel = list (self .tasks )
264
266
self .tasks .clear ()
267
+ for task in to_cancel :
268
+ task .cancel ()
265
269
266
270
# pylint: disable=R0902
267
271
class ServerInvocationContext (ObjectContext ):
@@ -358,13 +362,13 @@ def on_attempt_finished(self):
358
362
async def receive_and_notify_input (self ):
359
363
"""Receive input from the state machine."""
360
364
chunk = await self .receive ()
361
- if chunk .get ('type' ) == 'http.request' :
365
+ if chunk .get ('type' ) == 'http.disconnect' :
366
+ raise DisconnectedException ()
367
+ if chunk .get ('body' , None ) is not None :
362
368
assert isinstance (chunk ['body' ], bytes )
363
369
self .vm .notify_input (chunk ['body' ])
364
370
if not chunk .get ('more_body' , False ):
365
371
self .vm .notify_input_closed ()
366
- if chunk .get ('type' ) == 'http.disconnect' :
367
- raise DisconnectedException ()
368
372
369
373
async def take_and_send_output (self ):
370
374
"""Take output from state machine and send it"""
@@ -398,9 +402,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398
402
return
399
403
if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
400
404
raise TerminalError ("cancelled" , 409 )
401
- if isinstance (do_progress_response , DoProgressReadFromInput ):
402
- await self .receive_and_notify_input ()
403
- continue
404
405
if isinstance (do_progress_response , DoProgressExecuteRun ):
405
406
fn = self .run_coros_to_execute [do_progress_response .handle ]
406
407
del self .run_coros_to_execute [do_progress_response .handle ]
@@ -414,17 +415,16 @@ async def wrapper(f):
414
415
task = asyncio .create_task (wrapper (fn ))
415
416
self .tasks .add (task )
416
417
continue
417
- if isinstance (do_progress_response , DoWaitPendingRun ):
418
- sync_task = asyncio .create_task (self .sync_point .wait ())
419
- read_task = asyncio .create_task (self .receive_and_notify_input ())
418
+ if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
419
+ sync_task = asyncio .create_task (self .sync_point .awaiter ())
420
420
self .tasks .add (sync_task )
421
+
422
+ read_task = asyncio .create_task (self .receive_and_notify_input ())
421
423
self .tasks .add (read_task )
424
+
422
425
done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
423
426
if read_task in done :
424
- _ = read_task .result () # rethrow any exception
425
- if sync_task in done :
426
- continue
427
-
427
+ _ = read_task .result () # propagate exception
428
428
429
429
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
430
430
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +520,8 @@ async def create_run_coroutine(self,
520
520
except TerminalError as t :
521
521
failure = Failure (code = t .status_code , message = t .message )
522
522
self .vm .propose_run_completion_failure (handle , failure )
523
+ except asyncio .CancelledError as e :
524
+ raise e from None
523
525
# pylint: disable=W0718
524
526
except Exception as e :
525
527
if max_attempts is None and max_retry_duration is None :
0 commit comments