17
17
"""This module contains the restate context implementation based on the server"""
18
18
19
19
import asyncio
20
+ import contextvars
20
21
from datetime import timedelta
21
22
import inspect
22
23
import functools
@@ -87,7 +88,6 @@ async def get(self) -> T:
87
88
"""Get the value of the future."""
88
89
if self .task is None :
89
90
self .task = asyncio .create_task (self .async_def ())
90
-
91
91
return await self .task
92
92
93
93
def __await__ (self ):
@@ -236,6 +236,33 @@ async def arrive(self):
236
236
async with self ._cond :
237
237
self ._cond .notify_all ()
238
238
239
+ class Tasks :
240
+ """
241
+ This class implements a list of tasks.
242
+ """
243
+
244
+ def __init__ (self ) -> None :
245
+ self .tasks : List [asyncio .Future ] = []
246
+
247
+ def add (self , task : asyncio .Future ):
248
+ """Add a task to the list."""
249
+ self .tasks .append (task )
250
+
251
+ def safe_remove (_ ):
252
+ """Remove the task from the list."""
253
+ try :
254
+ self .tasks .remove (task )
255
+ except ValueError :
256
+ pass
257
+
258
+ task .add_done_callback (safe_remove )
259
+
260
+ def cancel (self ):
261
+ """Cancel all tasks in the list."""
262
+ for task in self .tasks :
263
+ task .cancel ()
264
+ self .tasks .clear ()
265
+
239
266
# pylint: disable=R0902
240
267
class ServerInvocationContext (ObjectContext ):
241
268
"""This class implements the context for the restate framework based on the server."""
@@ -257,6 +284,7 @@ def __init__(self,
257
284
self .run_coros_to_execute : dict [int , Callable [[], Awaitable [None ]]] = {}
258
285
self .sync_point = SyncPoint ()
259
286
self .request_finished_event = asyncio .Event ()
287
+ self .tasks = Tasks ()
260
288
261
289
async def enter (self ):
262
290
"""Invoke the user code."""
@@ -320,6 +348,11 @@ async def leave(self):
320
348
def on_attempt_finished (self ):
321
349
"""Notify the attempt finished event."""
322
350
self .request_finished_event .set ()
351
+ try :
352
+ self .tasks .cancel ()
353
+ except asyncio .CancelledError :
354
+ # ignore the cancelled error
355
+ pass
323
356
324
357
325
358
async def receive_and_notify_input (self ):
@@ -378,11 +411,14 @@ async def wrapper(f):
378
411
await self .take_and_send_output ()
379
412
await self .sync_point .arrive ()
380
413
381
- asyncio .create_task (wrapper (fn ))
414
+ task = asyncio .create_task (wrapper (fn ))
415
+ self .tasks .add (task )
382
416
continue
383
417
if isinstance (do_progress_response , DoWaitPendingRun ):
384
418
sync_task = asyncio .create_task (self .sync_point .wait ())
385
419
read_task = asyncio .create_task (self .receive_and_notify_input ())
420
+ self .tasks .add (sync_task )
421
+ self .tasks .add (read_task )
386
422
done , _ = await asyncio .wait ([sync_task , read_task ], return_when = asyncio .FIRST_COMPLETED )
387
423
if read_task in done :
388
424
_ = read_task .result () # rethrow any exception
@@ -459,6 +495,7 @@ def request(self) -> Request:
459
495
attempt_finished_event = ServerTeardownEvent (self .request_finished_event ),
460
496
)
461
497
498
+ # pylint: disable=R0914
462
499
async def create_run_coroutine (self ,
463
500
handle : int ,
464
501
action : Callable [[], T ] | Callable [[], Awaitable [T ]],
@@ -471,7 +508,12 @@ async def create_run_coroutine(self,
471
508
if inspect .iscoroutinefunction (action ):
472
509
action_result : T = await action () # type: ignore
473
510
else :
474
- action_result = typing .cast (T , await asyncio .to_thread (action ))
511
+ loop = asyncio .get_running_loop ()
512
+ ctx = contextvars .copy_context ()
513
+ func_call = functools .partial (ctx .run , action )
514
+ action_result_future = loop .run_in_executor (None , func_call )
515
+ self .tasks .add (action_result_future )
516
+ action_result = typing .cast (T , await action_result_future )
475
517
476
518
buffer = serde .serialize (action_result )
477
519
self .vm .propose_run_completion_success (handle , buffer )
0 commit comments