@@ -223,45 +223,47 @@ class SyncPoint:
223
223
This class implements a synchronization point.
224
224
"""
225
225
226
- def __init__ (self ):
227
- self ._cond = asyncio .Condition ()
226
+ def __init__ (self ) -> None :
227
+ self .cond : asyncio .Event | None = None
228
228
229
- async def wait (self ):
229
+ def awaiter (self ):
230
230
"""Wait for the sync point."""
231
- async with self ._cond :
232
- await self ._cond .wait ()
231
+ if self .cond is None :
232
+ self .cond = asyncio .Event ()
233
+ return self .cond .wait ()
233
234
234
235
async def arrive (self ):
235
- """Arrive at the sync point."""
236
- async with self ._cond :
237
- self ._cond . notify_all ()
236
+ """arrive at the sync point."""
237
+ if self .cond is not None :
238
+ self .cond . set ()
238
239
239
240
class Tasks :
240
241
"""
241
242
This class implements a list of tasks.
242
243
"""
243
244
244
245
def __init__ (self ) -> None :
245
- self .tasks : List [asyncio .Future ] = []
246
+ self .tasks : set [asyncio .Future ] = set ()
246
247
247
248
def add (self , task : asyncio .Future ):
248
249
"""Add a task to the list."""
249
- self .tasks .append (task )
250
+ self .tasks .add (task )
250
251
251
252
def safe_remove (_ ):
252
253
"""Remove the task from the list."""
253
254
try :
254
255
self .tasks .remove (task )
255
- except ValueError :
256
+ except KeyError :
256
257
pass
257
258
258
259
task .add_done_callback (safe_remove )
259
260
260
261
def cancel (self ):
261
262
"""Cancel all tasks in the list."""
262
- for task in self .tasks :
263
- task .cancel ()
263
+ to_cancel = list (self .tasks )
264
264
self .tasks .clear ()
265
+ for task in to_cancel :
266
+ task .cancel ()
265
267
266
268
# pylint: disable=R0902
267
269
class ServerInvocationContext (ObjectContext ):
@@ -358,13 +360,13 @@ def on_attempt_finished(self):
358
360
async def receive_and_notify_input (self ):
359
361
"""Receive input from the state machine."""
360
362
chunk = await self .receive ()
361
- if chunk .get ('type' ) == 'http.request' :
363
+ if chunk .get ('type' ) == 'http.disconnect' :
364
+ raise DisconnectedException ()
365
+ if chunk .get ('body' , None ) is not None :
362
366
assert isinstance (chunk ['body' ], bytes )
363
367
self .vm .notify_input (chunk ['body' ])
364
368
if not chunk .get ('more_body' , False ):
365
369
self .vm .notify_input_closed ()
366
- if chunk .get ('type' ) == 'http.disconnect' :
367
- raise DisconnectedException ()
368
370
369
371
async def take_and_send_output (self ):
370
372
"""Take output from state machine and send it"""
@@ -398,9 +400,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398
400
return
399
401
if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
400
402
raise TerminalError ("cancelled" , 409 )
401
- if isinstance (do_progress_response , DoProgressReadFromInput ):
402
- await self .receive_and_notify_input ()
403
- continue
404
403
if isinstance (do_progress_response , DoProgressExecuteRun ):
405
404
fn = self .run_coros_to_execute [do_progress_response .handle ]
406
405
del self .run_coros_to_execute [do_progress_response .handle ]
@@ -414,17 +413,16 @@ async def wrapper(f):
414
413
task = asyncio .create_task (wrapper (fn ))
415
414
self .tasks .add (task )
416
415
continue
417
- if isinstance (do_progress_response , DoWaitPendingRun ):
418
- sync_task = asyncio .create_task (self .sync_point .wait ())
419
- read_task = asyncio .create_task (self .receive_and_notify_input ())
416
+ if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
417
+ sync_task = asyncio .create_task (self .sync_point .awaiter ())
420
418
self .tasks .add (sync_task )
419
+
420
+ read_task = asyncio .create_task (self .receive_and_notify_input ())
421
421
self .tasks .add (read_task )
422
+
422
423
done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
423
424
if read_task in done :
424
- _ = read_task .result () # rethrow any exception
425
- if sync_task in done :
426
- continue
427
-
425
+ _ = read_task .result () # propagate exception
428
426
429
427
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
430
428
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +518,8 @@ async def create_run_coroutine(self,
520
518
except TerminalError as t :
521
519
failure = Failure (code = t .status_code , message = t .message )
522
520
self .vm .propose_run_completion_failure (handle , failure )
521
+ except asyncio .CancelledError as e :
522
+ raise e from None
523
523
# pylint: disable=W0718
524
524
except Exception as e :
525
525
if max_attempts is None and max_retry_duration is None :
0 commit comments