Skip to content

Commit 0fb333d

Browse files
Use task cancellation instead of custom suspension exception (#132)
1 parent 9619a96 commit 0fb333d

File tree

3 files changed

+80
-36
lines changed

3 files changed

+80
-36
lines changed

python/restate/server_context.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from restate.handler import Handler, handler_from_callable, invoke_handler
3434
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
3535
from restate.server_types import ReceiveChannel, Send
36-
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
36+
from restate.vm import Failure, Invocation, NotReady, VMWrapper, RunRetryConfig, Suspended # pylint: disable=line-too-long
3737
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
3838
import typing_extensions
3939

@@ -160,7 +160,7 @@ def __init__(self, context: "ServerInvocationContext", handle: int) -> None:
160160
async def coro() -> str:
161161
if not context.vm.is_completed(handle):
162162
await context.create_poll_or_cancel_coroutine([handle])
163-
invocation_id = context.must_take_notification(handle)
163+
invocation_id = await context.must_take_notification(handle)
164164
return typing.cast(str, invocation_id)
165165

166166
self.future = LazyFuture(coro)
@@ -200,7 +200,7 @@ def resolve(self, value: Any) -> Awaitable[None]:
200200
async def await_point():
201201
if not self.server_context.vm.is_completed(handle):
202202
await self.server_context.create_poll_or_cancel_coroutine([handle])
203-
self.server_context.must_take_notification(handle)
203+
await self.server_context.must_take_notification(handle)
204204

205205
return ServerDurableFuture(self.server_context, handle, await_point)
206206

@@ -213,7 +213,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
213213
async def await_point():
214214
if not self.server_context.vm.is_completed(handle):
215215
await self.server_context.create_poll_or_cancel_coroutine([handle])
216-
self.server_context.must_take_notification(handle)
216+
await self.server_context.must_take_notification(handle)
217217

218218
return ServerDurableFuture(self.server_context, handle, await_point)
219219

@@ -273,6 +273,19 @@ def update_restate_context_is_replaying(vm: VMWrapper):
273273
"""Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*"""
274274
restate_context_is_replaying.set(vm.is_replaying())
275275

276+
async def cancel_current_task():
277+
"""Cancel the current task"""
278+
current_task = asyncio.current_task()
279+
if current_task is not None:
280+
# Cancel through asyncio API
281+
current_task.cancel(
282+
"Cancelled by Restate SDK, you should not call any Context method after this exception is thrown."
283+
)
284+
# Sleep 0 will pop up the cancellation
285+
await asyncio.sleep(0)
286+
else:
287+
raise asyncio.CancelledError("Cancelled by Restate SDK, you should not call any Context method after this exception is thrown.")
288+
276289
# pylint: disable=R0902
277290
class ServerInvocationContext(ObjectContext):
278291
"""This class implements the context for the restate framework based on the server."""
@@ -312,7 +325,7 @@ async def enter(self):
312325
self.vm.sys_write_output_failure(failure)
313326
self.vm.sys_end()
314327
# pylint: disable=W0718
315-
except SuspendedException:
328+
except asyncio.CancelledError:
316329
pass
317330
except DisconnectedException:
318331
raise
@@ -372,9 +385,19 @@ async def take_and_send_output(self):
372385
'more_body': True,
373386
})
374387

375-
def must_take_notification(self, handle):
388+
async def must_take_notification(self, handle):
376389
"""Take notification, which must be present"""
377390
res = self.vm.take_notification(handle)
391+
if isinstance(res, Exception):
392+
# We might need to write out something at this point.
393+
await self.take_and_send_output()
394+
# Print this exception, might be relevant for the user
395+
traceback.print_exception(res)
396+
await cancel_current_task()
397+
if isinstance(res, Suspended):
398+
# We might need to write out something at this point.
399+
await self.take_and_send_output()
400+
await cancel_current_task()
378401
if isinstance(res, NotReady):
379402
raise ValueError(f"Unexpected value error: {handle}")
380403
if res is None:
@@ -383,12 +406,21 @@ def must_take_notification(self, handle):
383406
raise TerminalError(res.message, res.code)
384407
return res
385408

386-
387409
async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
388410
"""Create a coroutine to poll the handle."""
389411
await self.take_and_send_output()
390412
while True:
391413
do_progress_response = self.vm.do_progress(handles)
414+
if isinstance(do_progress_response, Exception):
415+
# We might need to write out something at this point.
416+
await self.take_and_send_output()
417+
# Print this exception, might be relevant for the user
418+
traceback.print_exception(do_progress_response)
419+
await cancel_current_task()
420+
if isinstance(do_progress_response, Suspended):
421+
# We might need to write out something at this point.
422+
await self.take_and_send_output()
423+
await cancel_current_task()
392424
if isinstance(do_progress_response, DoProgressAnyCompleted):
393425
# One of the handles completed
394426
return
@@ -425,7 +457,7 @@ def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = N
425457
async def fetch_result():
426458
if not self.vm.is_completed(handle):
427459
await self.create_poll_or_cancel_coroutine([handle])
428-
res = self.must_take_notification(handle)
460+
res = await self.must_take_notification(handle)
429461
if res is None or serde is None:
430462
return res
431463
if isinstance(res, bytes):
@@ -443,15 +475,15 @@ def create_sleep_future(self, handle: int) -> ServerDurableSleepFuture:
443475
async def transform():
444476
if not self.vm.is_completed(handle):
445477
await self.create_poll_or_cancel_coroutine([handle])
446-
self.must_take_notification(handle)
478+
await self.must_take_notification(handle)
447479
return ServerDurableSleepFuture(self, handle, transform)
448480

449481
def create_call_future(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
450482
"""Create a durable future."""
451483
async def inv_id_factory():
452484
if not self.vm.is_completed(invocation_id_handle):
453485
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
454-
return self.must_take_notification(invocation_id_handle)
486+
return await self.must_take_notification(invocation_id_handle)
455487

456488
return ServerCallDurableFuture(self, handle, self._create_fetch_result_coroutine(handle, serde), inv_id_factory)
457489

python/restate/vm.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from dataclasses import dataclass
1919
import typing
20-
from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
20+
from restate._internal import PyVM, PyHeader, PyFailure, VMException, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
2121

2222
@dataclass
2323
class Invocation:
@@ -53,19 +53,18 @@ class NotReady:
5353
NotReady
5454
"""
5555

56-
class SuspendedException(Exception):
57-
"""
58-
Suspended Exception
59-
"""
60-
def __init__(self, *args: object) -> None:
61-
super().__init__(*args)
62-
6356
NOT_READY = NotReady()
64-
SUSPENDED = SuspendedException()
6557
CANCEL_HANDLE = CANCEL_NOTIFICATION_HANDLE
6658

6759
NotificationType = typing.Optional[typing.Union[bytes, Failure, NotReady, list[str], str]]
6860

61+
class Suspended:
62+
"""
63+
Represents a suspended error
64+
"""
65+
66+
SUSPENDED = Suspended()
67+
6968
class DoProgressAnyCompleted:
7069
"""
7170
Represents a notification that any of the handles has completed.
@@ -151,11 +150,16 @@ def is_completed(self, handle: int) -> bool:
151150
"""Returns true when the notification handle is completed and hasn't been taken yet."""
152151
return self.vm.is_completed(handle)
153152

154-
def do_progress(self, handles: list[int]) -> DoProgressResult:
153+
# pylint: disable=R0911
154+
def do_progress(self, handles: list[int]) \
155+
-> typing.Union[DoProgressResult, Exception, Suspended]:
155156
"""Do progress with notifications."""
156-
result = self.vm.do_progress(handles)
157+
try:
158+
result = self.vm.do_progress(handles)
159+
except VMException as e:
160+
return e
157161
if isinstance(result, PySuspended):
158-
raise SUSPENDED
162+
return SUSPENDED
159163
if isinstance(result, PyDoProgressAnyCompleted):
160164
return DO_PROGRESS_ANY_COMPLETED
161165
if isinstance(result, PyDoProgressReadFromInput):
@@ -166,11 +170,17 @@ def do_progress(self, handles: list[int]) -> DoProgressResult:
166170
return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
167171
if isinstance(result, PyDoWaitForPendingRun):
168172
return DO_WAIT_PENDING_RUN
169-
raise ValueError(f"Unknown progress type: {result}")
173+
return ValueError(f"Unknown progress type: {result}")
170174

171-
def take_notification(self, handle: int) -> NotificationType:
175+
def take_notification(self, handle: int) \
176+
-> typing.Union[NotificationType, Exception, Suspended]:
172177
"""Take the result of an asynchronous operation."""
173-
result = self.vm.take_notification(handle)
178+
try:
179+
result = self.vm.take_notification(handle)
180+
except VMException as e:
181+
return e
182+
if isinstance(result, PySuspended):
183+
return SUSPENDED
174184
if result is None:
175185
return NOT_READY
176186
if isinstance(result, PyVoid):
@@ -190,10 +200,7 @@ def take_notification(self, handle: int) -> NotificationType:
190200
code = result.code
191201
message = result.message
192202
return Failure(code, message)
193-
if isinstance(result, PySuspended):
194-
# the state machine had suspended
195-
raise SUSPENDED
196-
raise ValueError(f"Unknown result type: {result}")
203+
return ValueError(f"Unknown result type: {result}")
197204

198205
def sys_input(self) -> Invocation:
199206
"""

src/lib.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ create_exception!(
249249
restate_sdk_python_core,
250250
VMException,
251251
pyo3::exceptions::PyException,
252-
"Restate VM exception."
252+
"Protocol state machine exception."
253253
);
254254

255255
impl From<PyVMError> for PyErr {
@@ -761,21 +761,26 @@ impl ErrorFormatter for PythonErrorFormatter {
761761
fn display_closed_error(&self, f: &mut fmt::Formatter<'_>, event: &str) -> fmt::Result {
762762
write!(f, "Execution is suspended, but the handler is still attempting to make progress (calling '{event}'). This can happen:
763763
764-
* If the SuspendedException is caught. Make sure you NEVER catch the SuspendedException, e.g. avoid:
764+
* If you don't need to handle task cancellation, just avoid catch all statements. Don't do:
765765
try:
766766
# Code
767767
except:
768-
# This catches all exceptions, including the SuspendedException!
768+
# This catches all exceptions, including the asyncio.CancelledError!
769+
# '{event}' <- This operation prints this exception
769770
770-
And use instead:
771+
Do instead:
771772
try:
772773
# Code
773774
except TerminalException:
774-
# In Restate handlers you typically want to catch TerminalException
775+
# In Restate handlers you typically want to catch TerminalException only
775776
776-
Check https://docs.restate.dev/develop/python/durable-steps#run for more details on run error handling.
777+
* To catch ctx.run/ctx.run_typed errors, check https://docs.restate.dev/develop/python/durable-steps#run for more details.
777778
778-
* If you use the context after the handler completed, e.g. moving the context to another thread. Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.")
779+
* If the asyncio.CancelledError is caught, you must not run any Context operation in the except arm.
780+
Check https://docs.python.org/3/library/asyncio-task.html#task-cancellation for more details on task cancellation.
781+
782+
* If you use the context after the handler completed, e.g. moving the context to another thread.
783+
Check https://docs.restate.dev/develop/python/concurrent-tasks for more details on how to create durable concurrent tasks in Python.")
779784
}
780785
}
781786

0 commit comments

Comments
 (0)