Skip to content

Commit f15c445

Browse files
committed
Add a shortcut f.cancel_invocation()
1 parent 00e7d76 commit f15c445

File tree

3 files changed

+70
-28
lines changed

3 files changed

+70
-28
lines changed

examples/concurrent_greeter.py

+32-28
Original file line numberDiff line numberDiff line change
@@ -13,48 +13,52 @@
1313
# pylint: disable=W0613
1414
# pylint: disable=C0115
1515
# pylint: disable=R0903
16+
# pylint: disable=C0301
1617

17-
from datetime import timedelta
18+
import typing
1819

1920
from pydantic import BaseModel
20-
from restate import Service, Context
21-
from restate import wait_completed, RestateDurableSleepFuture, RestateDurableCallFuture
22-
23-
from greeter import greet as g
21+
from restate import wait_completed, Service, Context
2422

2523
# models
2624
class GreetingRequest(BaseModel):
2725
name: str
2826

2927
class Greeting(BaseModel):
30-
message: str
28+
messages: typing.List[str]
3129

32-
# service
30+
class Message(BaseModel):
31+
role: str
32+
content: str
3333

3434
concurrent_greeter = Service("concurrent_greeter")
3535

36-
3736
@concurrent_greeter.handler()
3837
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
39-
g1 = ctx.service_call(g, arg="1")
40-
g2 = ctx.service_call(g, arg="2")
41-
g3 = ctx.sleep(timedelta(milliseconds=100))
42-
43-
done, pending = await wait_completed(g1, g2, g3)
44-
45-
for f in done:
46-
if isinstance(f, RestateDurableSleepFuture):
47-
print("Timeout :(x", flush=True)
48-
elif isinstance(f, RestateDurableCallFuture):
49-
# the result should be ready.
50-
print(await f)
51-
#
52-
# let's cancel the pending calls then
53-
#
38+
claude_sonnet.as_handler(ctx)
39+
claude = ctx.service_call(claude_sonnet, arg=Message(role="user", content=f"please greet {req.name}"))
40+
openai = ctx.service_call(open_ai, arg=Message(role="user", content=f"please greet {req.name}"))
41+
42+
pending, done = await wait_completed(claude, openai)
43+
44+
# collect the completed greetings
45+
greetings = [await f for f in done]
46+
47+
# cancel the pending calls
5448
for f in pending:
55-
if isinstance(f, RestateDurableCallFuture):
56-
inv = await f.invocation_id()
57-
print(f"Canceling {inv}", flush=True)
58-
ctx.cancel_invocation(inv)
49+
await f.cancel_invocation() # type: ignore
50+
51+
return Greeting(messages=greetings)
52+
53+
54+
# not really interesting, just for this demo:
55+
56+
@concurrent_greeter.handler()
57+
async def claude_sonnet(ctx: Context, req: Message) -> str:
58+
return f"Bonjour {req.content[13:]}!"
59+
60+
@concurrent_greeter.handler()
61+
async def open_ai(ctx: Context, req: Message) -> str:
62+
return f"Hello {req.content[13:]}!"
63+
5964

60-
return Greeting(message=f"Hello {req.name}!")

python/restate/context.py

+22
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ async def invocation_id(self) -> str:
5151
Returns the invocation id of the call.
5252
"""
5353

54+
@abc.abstractmethod
55+
async def cancel_invocation(self) -> None:
56+
"""
57+
Cancels the invocation.
58+
59+
Just a utility shortcut to:
60+
.. code-block:: python
61+
62+
await ctx.cancel_invocation(await f.invocation_id())
63+
"""
64+
5465

5566
class RestateDurableSleepFuture(RestateDurableFuture[None]):
5667
"""
@@ -136,6 +147,17 @@ async def invocation_id(self) -> str:
136147
Returns the invocation id of the send operation.
137148
"""
138149

150+
@abc.abstractmethod
151+
async def cancel_invocation(self) -> None:
152+
"""
153+
Cancels the invocation.
154+
155+
Just a utility shortcut to:
156+
.. code-block:: python
157+
158+
await ctx.cancel_invocation(await f.invocation_id())
159+
"""
160+
139161
class Context(abc.ABC):
140162
"""
141163
Represents the context of the current invocation.

python/restate/server_context.py

+16
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,22 @@ async def invocation_id(self) -> str:
106106
"""Get the invocation id."""
107107
return await self.invocation_id_future.get()
108108

109+
async def cancel_invocation(self) -> None:
110+
"""
111+
Cancels the invocation.
112+
113+
Just a utility shortcut to:
114+
.. code-block:: python
115+
116+
await ctx.cancel_invocation(await f.invocation_id())
117+
"""
118+
109119
class ServerSendHandle(SendHandle):
110120
"""This class implements the send API"""
111121

112122
def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
113123
super().__init__()
124+
self.context = context
114125

115126
async def coro():
116127
if not context.vm.is_completed(handle):
@@ -123,6 +134,11 @@ async def invocation_id(self) -> str:
123134
"""Get the invocation id."""
124135
return await self.future
125136

137+
async def cancel_invocation(self) -> None:
138+
"""Cancel the invocation."""
139+
invocation_id = await self.invocation_id()
140+
self.context.cancel_invocation(invocation_id)
141+
126142
async def async_value(n: Callable[[], T]) -> T:
127143
"""convert a simple value to a coroutine."""
128144
return n()

0 commit comments

Comments
 (0)