Skip to content

Auto cancel tasks when an attempt is over #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""This module contains the restate context implementation based on the server"""

import asyncio
import contextvars
from datetime import timedelta
import inspect
import functools
Expand Down Expand Up @@ -87,7 +88,6 @@ async def get(self) -> T:
"""Get the value of the future."""
if self.task is None:
self.task = asyncio.create_task(self.async_def())

return await self.task

def __await__(self):
Expand Down Expand Up @@ -236,6 +236,33 @@ async def arrive(self):
async with self._cond:
self._cond.notify_all()

class Tasks:
"""
This class implements a list of tasks.
"""

def __init__(self) -> None:
self.tasks: List[asyncio.Future] = []

def add(self, task: asyncio.Future):
"""Add a task to the list."""
self.tasks.append(task)

def safe_remove(_):
"""Remove the task from the list."""
try:
self.tasks.remove(task)
except ValueError:
pass

task.add_done_callback(safe_remove)

def cancel(self):
"""Cancel all tasks in the list."""
for task in self.tasks:
task.cancel()
self.tasks.clear()

# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""
Expand All @@ -257,6 +284,7 @@ def __init__(self,
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
self.sync_point = SyncPoint()
self.request_finished_event = asyncio.Event()
self.tasks = Tasks()

async def enter(self):
"""Invoke the user code."""
Expand Down Expand Up @@ -320,6 +348,11 @@ async def leave(self):
def on_attempt_finished(self):
"""Notify the attempt finished event."""
self.request_finished_event.set()
try:
self.tasks.cancel()
except asyncio.CancelledError:
# ignore the cancelled error
pass


async def receive_and_notify_input(self):
Expand Down Expand Up @@ -378,11 +411,14 @@ async def wrapper(f):
await self.take_and_send_output()
await self.sync_point.arrive()

asyncio.create_task(wrapper(fn))
task = asyncio.create_task(wrapper(fn))
self.tasks.add(task)
continue
if isinstance(do_progress_response, DoWaitPendingRun):
sync_task = asyncio.create_task(self.sync_point.wait())
read_task = asyncio.create_task(self.receive_and_notify_input())
self.tasks.add(sync_task)
self.tasks.add(read_task)
done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
if read_task in done:
_ = read_task.result() # rethrow any exception
Expand Down Expand Up @@ -459,6 +495,7 @@ def request(self) -> Request:
attempt_finished_event=ServerTeardownEvent(self.request_finished_event),
)

# pylint: disable=R0914
async def create_run_coroutine(self,
handle: int,
action: Callable[[], T] | Callable[[], Awaitable[T]],
Expand All @@ -471,7 +508,12 @@ async def create_run_coroutine(self,
if inspect.iscoroutinefunction(action):
action_result: T = await action() # type: ignore
else:
action_result = typing.cast(T, await asyncio.to_thread(action))
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, action)
action_result_future = loop.run_in_executor(None, func_call)
self.tasks.add(action_result_future)
action_result = typing.cast(T, await action_result_future)

buffer = serde.serialize(action_result)
self.vm.propose_run_completion_success(handle, buffer)
Expand Down