Skip to content

Commit 1c5f4dc

Browse files
authored
Cancel any spawned tasks when the attempt is over (#84)
This commit, automatically cancels any asyncio.Tasks that were spawn during an attempt, after the attempt was completed.
1 parent 7bcd910 commit 1c5f4dc

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

python/restate/server_context.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""This module contains the restate context implementation based on the server"""
1818

1919
import asyncio
20+
import contextvars
2021
from datetime import timedelta
2122
import inspect
2223
import functools
@@ -87,7 +88,6 @@ async def get(self) -> T:
8788
"""Get the value of the future."""
8889
if self.task is None:
8990
self.task = asyncio.create_task(self.async_def())
90-
9191
return await self.task
9292

9393
def __await__(self):
@@ -236,6 +236,33 @@ async def arrive(self):
236236
async with self._cond:
237237
self._cond.notify_all()
238238

239+
class Tasks:
240+
"""
241+
This class implements a list of tasks.
242+
"""
243+
244+
def __init__(self) -> None:
245+
self.tasks: List[asyncio.Future] = []
246+
247+
def add(self, task: asyncio.Future):
248+
"""Add a task to the list."""
249+
self.tasks.append(task)
250+
251+
def safe_remove(_):
252+
"""Remove the task from the list."""
253+
try:
254+
self.tasks.remove(task)
255+
except ValueError:
256+
pass
257+
258+
task.add_done_callback(safe_remove)
259+
260+
def cancel(self):
261+
"""Cancel all tasks in the list."""
262+
for task in self.tasks:
263+
task.cancel()
264+
self.tasks.clear()
265+
239266
# pylint: disable=R0902
240267
class ServerInvocationContext(ObjectContext):
241268
"""This class implements the context for the restate framework based on the server."""
@@ -257,6 +284,7 @@ def __init__(self,
257284
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
258285
self.sync_point = SyncPoint()
259286
self.request_finished_event = asyncio.Event()
287+
self.tasks = Tasks()
260288

261289
async def enter(self):
262290
"""Invoke the user code."""
@@ -320,6 +348,11 @@ async def leave(self):
320348
def on_attempt_finished(self):
321349
"""Notify the attempt finished event."""
322350
self.request_finished_event.set()
351+
try:
352+
self.tasks.cancel()
353+
except asyncio.CancelledError:
354+
# ignore the cancelled error
355+
pass
323356

324357

325358
async def receive_and_notify_input(self):
@@ -378,11 +411,14 @@ async def wrapper(f):
378411
await self.take_and_send_output()
379412
await self.sync_point.arrive()
380413

381-
asyncio.create_task(wrapper(fn))
414+
task = asyncio.create_task(wrapper(fn))
415+
self.tasks.add(task)
382416
continue
383417
if isinstance(do_progress_response, DoWaitPendingRun):
384418
sync_task = asyncio.create_task(self.sync_point.wait())
385419
read_task = asyncio.create_task(self.receive_and_notify_input())
420+
self.tasks.add(sync_task)
421+
self.tasks.add(read_task)
386422
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
387423
if read_task in done:
388424
_ = read_task.result() # rethrow any exception
@@ -459,6 +495,7 @@ def request(self) -> Request:
459495
attempt_finished_event=ServerTeardownEvent(self.request_finished_event),
460496
)
461497

498+
# pylint: disable=R0914
462499
async def create_run_coroutine(self,
463500
handle: int,
464501
action: Callable[[], T] | Callable[[], Awaitable[T]],
@@ -471,7 +508,12 @@ async def create_run_coroutine(self,
471508
if inspect.iscoroutinefunction(action):
472509
action_result: T = await action() # type: ignore
473510
else:
474-
action_result = typing.cast(T, await asyncio.to_thread(action))
511+
loop = asyncio.get_running_loop()
512+
ctx = contextvars.copy_context()
513+
func_call = functools.partial(ctx.run, action)
514+
action_result_future = loop.run_in_executor(None, func_call)
515+
self.tasks.add(action_result_future)
516+
action_result = typing.cast(T, await action_result_future)
475517

476518
buffer = serde.serialize(action_result)
477519
self.vm.propose_run_completion_success(handle, buffer)

0 commit comments

Comments
 (0)