Skip to content

Add a shortcut f.cancel_invocation() #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions examples/concurrent_greeter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,50 @@
# pylint: disable=W0613
# pylint: disable=C0115
# pylint: disable=R0903
# pylint: disable=C0301

from datetime import timedelta
import typing

from pydantic import BaseModel
from restate import Service, Context
from restate import wait_completed, RestateDurableSleepFuture, RestateDurableCallFuture

from greeter import greet as g
from restate import wait_completed, Service, Context

# models
class GreetingRequest(BaseModel):
name: str

class Greeting(BaseModel):
message: str
messages: typing.List[str]

# service
class Message(BaseModel):
role: str
content: str

concurrent_greeter = Service("concurrent_greeter")


@concurrent_greeter.handler()
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
g1 = ctx.service_call(g, arg="1")
g2 = ctx.service_call(g, arg="2")
g3 = ctx.sleep(timedelta(milliseconds=100))

done, pending = await wait_completed(g1, g2, g3)

for f in done:
if isinstance(f, RestateDurableSleepFuture):
print("Timeout :(x", flush=True)
elif isinstance(f, RestateDurableCallFuture):
# the result should be ready.
print(await f)
#
# let's cancel the pending calls then
#
claude_sonnet.as_handler(ctx)
claude = ctx.service_call(claude_sonnet, arg=Message(role="user", content=f"please greet {req.name}"))
openai = ctx.service_call(open_ai, arg=Message(role="user", content=f"please greet {req.name}"))

pending, done = await wait_completed(claude, openai)

# collect the completed greetings
greetings = [await f for f in done]

# cancel the pending calls
for f in pending:
if isinstance(f, RestateDurableCallFuture):
inv = await f.invocation_id()
print(f"Canceling {inv}", flush=True)
ctx.cancel_invocation(inv)
await f.cancel_invocation() # type: ignore

return Greeting(messages=greetings)

return Greeting(message=f"Hello {req.name}!")

# not really interesting, just for this demo:

@concurrent_greeter.handler()
async def claude_sonnet(ctx: Context, req: Message) -> str:
return f"Bonjour {req.content[13:]}!"

@concurrent_greeter.handler()
async def open_ai(ctx: Context, req: Message) -> str:
return f"Hello {req.content[13:]}!"
22 changes: 22 additions & 0 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ async def invocation_id(self) -> str:
Returns the invocation id of the call.
"""

@abc.abstractmethod
async def cancel_invocation(self) -> None:
"""
Cancels the invocation.

Just a utility shortcut to:
.. code-block:: python

await ctx.cancel_invocation(await f.invocation_id())
"""


class RestateDurableSleepFuture(RestateDurableFuture[None]):
"""
Expand Down Expand Up @@ -136,6 +147,17 @@ async def invocation_id(self) -> str:
Returns the invocation id of the send operation.
"""

@abc.abstractmethod
async def cancel_invocation(self) -> None:
"""
Cancels the invocation.

Just a utility shortcut to:
.. code-block:: python

await ctx.cancel_invocation(await f.invocation_id())
"""

class Context(abc.ABC):
"""
Represents the context of the current invocation.
Expand Down
16 changes: 16 additions & 0 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,22 @@ async def invocation_id(self) -> str:
"""Get the invocation id."""
return await self.invocation_id_future.get()

async def cancel_invocation(self) -> None:
"""
Cancels the invocation.

Just a utility shortcut to:
.. code-block:: python

await ctx.cancel_invocation(await f.invocation_id())
"""

class ServerSendHandle(SendHandle):
"""This class implements the send API"""

def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
super().__init__()
self.context = context

async def coro():
if not context.vm.is_completed(handle):
Expand All @@ -123,6 +134,11 @@ async def invocation_id(self) -> str:
"""Get the invocation id."""
return await self.future

async def cancel_invocation(self) -> None:
"""Cancel the invocation."""
invocation_id = await self.invocation_id()
self.context.cancel_invocation(invocation_id)

async def async_value(n: Callable[[], T]) -> T:
"""convert a simple value to a coroutine."""
return n()
Expand Down