From 624b3a6d28cf80cdf4d133b5b10b1a88f3e94da3 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 8 Apr 2025 13:20:58 +0000 Subject: [PATCH] Add a shortcut f.cancel_invocation() --- examples/concurrent_greeter.py | 58 +++++++++++++++++--------------- python/restate/context.py | 22 ++++++++++++ python/restate/server_context.py | 16 +++++++++ 3 files changed, 68 insertions(+), 28 deletions(-) diff --git a/examples/concurrent_greeter.py b/examples/concurrent_greeter.py index a7947b7..eed1ef3 100644 --- a/examples/concurrent_greeter.py +++ b/examples/concurrent_greeter.py @@ -13,48 +13,50 @@ # pylint: disable=W0613 # pylint: disable=C0115 # pylint: disable=R0903 +# pylint: disable=C0301 -from datetime import timedelta +import typing from pydantic import BaseModel -from restate import Service, Context -from restate import wait_completed, RestateDurableSleepFuture, RestateDurableCallFuture - -from greeter import greet as g +from restate import wait_completed, Service, Context # models class GreetingRequest(BaseModel): name: str class Greeting(BaseModel): - message: str + messages: typing.List[str] -# service +class Message(BaseModel): + role: str + content: str concurrent_greeter = Service("concurrent_greeter") - @concurrent_greeter.handler() async def greet(ctx: Context, req: GreetingRequest) -> Greeting: - g1 = ctx.service_call(g, arg="1") - g2 = ctx.service_call(g, arg="2") - g3 = ctx.sleep(timedelta(milliseconds=100)) - - done, pending = await wait_completed(g1, g2, g3) - - for f in done: - if isinstance(f, RestateDurableSleepFuture): - print("Timeout :(x", flush=True) - elif isinstance(f, RestateDurableCallFuture): - # the result should be ready. - print(await f) - # - # let's cancel the pending calls then - # + claude_sonnet.as_handler(ctx) + claude = ctx.service_call(claude_sonnet, arg=Message(role="user", content=f"please greet {req.name}")) + openai = ctx.service_call(open_ai, arg=Message(role="user", content=f"please greet {req.name}")) + + pending, done = await wait_completed(claude, openai) + + # collect the completed greetings + greetings = [await f for f in done] + + # cancel the pending calls for f in pending: - if isinstance(f, RestateDurableCallFuture): - inv = await f.invocation_id() - print(f"Canceling {inv}", flush=True) - ctx.cancel_invocation(inv) + await f.cancel_invocation() # type: ignore + + return Greeting(messages=greetings) - return Greeting(message=f"Hello {req.name}!") + +# not really interesting, just for this demo: + +@concurrent_greeter.handler() +async def claude_sonnet(ctx: Context, req: Message) -> str: + return f"Bonjour {req.content[13:]}!" + +@concurrent_greeter.handler() +async def open_ai(ctx: Context, req: Message) -> str: + return f"Hello {req.content[13:]}!" diff --git a/python/restate/context.py b/python/restate/context.py index 54580bc..17c5fb9 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -51,6 +51,17 @@ async def invocation_id(self) -> str: Returns the invocation id of the call. """ + @abc.abstractmethod + async def cancel_invocation(self) -> None: + """ + Cancels the invocation. + + Just a utility shortcut to: + .. code-block:: python + + await ctx.cancel_invocation(await f.invocation_id()) + """ + class RestateDurableSleepFuture(RestateDurableFuture[None]): """ @@ -136,6 +147,17 @@ async def invocation_id(self) -> str: Returns the invocation id of the send operation. """ + @abc.abstractmethod + async def cancel_invocation(self) -> None: + """ + Cancels the invocation. + + Just a utility shortcut to: + .. code-block:: python + + await ctx.cancel_invocation(await f.invocation_id()) + """ + class Context(abc.ABC): """ Represents the context of the current invocation. diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 6231eaf..9b24eaa 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -106,11 +106,22 @@ async def invocation_id(self) -> str: """Get the invocation id.""" return await self.invocation_id_future.get() + async def cancel_invocation(self) -> None: + """ + Cancels the invocation. + + Just a utility shortcut to: + .. code-block:: python + + await ctx.cancel_invocation(await f.invocation_id()) + """ + class ServerSendHandle(SendHandle): """This class implements the send API""" def __init__(self, context: "ServerInvocationContext", handle: int) -> None: super().__init__() + self.context = context async def coro(): if not context.vm.is_completed(handle): @@ -123,6 +134,11 @@ async def invocation_id(self) -> str: """Get the invocation id.""" return await self.future + async def cancel_invocation(self) -> None: + """Cancel the invocation.""" + invocation_id = await self.invocation_id() + self.context.cancel_invocation(invocation_id) + async def async_value(n: Callable[[], T]) -> T: """convert a simple value to a coroutine.""" return n()