@@ -225,15 +225,34 @@ class SyncPoint:
225
225
226
226
def __init__ (self ):
227
227
self ._cond = asyncio .Condition ()
228
+ self ._running = set [asyncio .Task ]()
228
229
229
- async def wait (self ):
230
+ async def wait (self , snapshot : set [ asyncio . Task ] ):
230
231
"""Wait for the sync point."""
232
+
233
+ def wait_fn ():
234
+ """
235
+ Check if any task had finished.
236
+ """
237
+ assert self ._cond .locked ()
238
+ for task in snapshot :
239
+ if task not in self ._running :
240
+ return True
241
+ return False
242
+
231
243
async with self ._cond :
232
- await self ._cond .wait ( )
244
+ await self ._cond .wait_for ( wait_fn )
233
245
234
- async def arrive (self ):
246
+ async def enter (self ):
235
247
"""Arrive at the sync point."""
236
248
async with self ._cond :
249
+ self ._running .add (asyncio .current_task ())
250
+ await self ._cond .notify_all ()
251
+
252
+ async def leave (self ):
253
+ """leave the sync point."""
254
+ async with self ._cond :
255
+ self ._running .remove (asyncio .current_task ())
237
256
self ._cond .notify_all ()
238
257
239
258
class Tasks :
@@ -398,33 +417,31 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
398
417
return
399
418
if isinstance (do_progress_response , DoProgressCancelSignalReceived ):
400
419
raise TerminalError ("cancelled" , 409 )
401
- if isinstance (do_progress_response , DoProgressReadFromInput ):
402
- await self .receive_and_notify_input ()
403
- continue
404
420
if isinstance (do_progress_response , DoProgressExecuteRun ):
405
421
fn = self .run_coros_to_execute [do_progress_response .handle ]
406
422
del self .run_coros_to_execute [do_progress_response .handle ]
407
423
assert fn is not None
408
424
409
425
async def wrapper (f ):
426
+ await self .sync_point .enter ()
410
427
await f ()
411
428
await self .take_and_send_output ()
412
- await self .sync_point .arrive ()
429
+ await self .sync_point .leave ()
413
430
414
431
task = asyncio .create_task (wrapper (fn ))
415
432
self .tasks .add (task )
433
+
416
434
continue
417
- if isinstance (do_progress_response , DoWaitPendingRun ):
418
- sync_task = asyncio .create_task (self .sync_point .wait ())
435
+ if isinstance (do_progress_response , (DoWaitPendingRun , DoProgressReadFromInput )):
436
+ snapshot = self .sync_point ._running .copy ()
437
+ sync_task = asyncio .create_task (self .sync_point .wait (snapshot ))
419
438
read_task = asyncio .create_task (self .receive_and_notify_input ())
420
439
self .tasks .add (sync_task )
421
440
self .tasks .add (read_task )
422
441
done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
423
442
if read_task in done :
424
443
_ = read_task .result () # rethrow any exception
425
- if sync_task in done :
426
- continue
427
-
444
+
428
445
429
446
def _create_fetch_result_coroutine (self , handle : int , serde : Serde [T ] | None = None ):
430
447
"""Create a coroutine that fetches a result from a notification handle."""
@@ -520,6 +537,8 @@ async def create_run_coroutine(self,
520
537
except TerminalError as t :
521
538
failure = Failure (code = t .status_code , message = t .message )
522
539
self .vm .propose_run_completion_failure (handle , failure )
540
+ except asyncio .CancelledError as e :
541
+ raise e from None
523
542
# pylint: disable=W0718
524
543
except Exception as e :
525
544
if max_attempts is None and max_retry_duration is None :
0 commit comments