From 8f73eb28f76b56a095134c2fbe19010e4fe5d1d6 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 14 Jan 2025 18:09:14 +0100 Subject: [PATCH] Protocol V4 --- Cargo.lock | 123 +++++++++--------- Cargo.toml | 2 +- README.md | 6 +- python/restate/discovery.py | 4 +- python/restate/server_context.py | 142 +++++++++++++-------- python/restate/vm.py | 62 ++++----- src/lib.rs | 213 ++++++++++++++++++++----------- 7 files changed, 328 insertions(+), 224 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5655b81..d99aa8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,15 +13,15 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "base64" @@ -55,15 +55,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytes" -version = "1.6.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "bytes-utils" @@ -77,9 +77,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.13" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -98,9 +98,9 @@ checksum = "68ff6be19477a1bd5441f382916a89bc2a0b2c35db6d41e0f6e8538bf6d6463f" [[package]] name = "cpufeatures" -version = "0.2.17" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -178,9 +178,9 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -224,15 +224,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "matchers" @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "overload" @@ -332,15 +332,15 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" [[package]] name = "powerfmt" @@ -445,9 +445,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.36" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -463,14 +463,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -484,13 +484,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -501,13 +501,13 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "restate-sdk-python-core" -version = "0.5.0" +version = "0.5.1" dependencies = [ "pyo3", "restate-sdk-shared-core", @@ -517,8 +517,7 @@ dependencies = [ [[package]] name = "restate-sdk-shared-core" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693b460bba5579ddc9047df94e116f65416bcec172e9ec1c06ba5924842eb7ff" +source = "git+https://github.com/restatedev/sdk-shared-core.git?branch=main#33590c3f47f83959d13bf2f59d8eb67d4cc87606" dependencies = [ "base64 0.22.1", "bs58", @@ -552,30 +551,30 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.204" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", @@ -584,9 +583,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -668,9 +667,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.98" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -679,9 +678,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "thiserror" @@ -781,9 +780,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -792,9 +791,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -803,9 +802,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", "valuable", @@ -824,9 +823,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.18" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ "matchers", "nu-ansi-term", @@ -848,9 +847,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unindent" diff --git a/Cargo.toml b/Cargo.toml index 5359195..27876ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ doc = false [dependencies] pyo3 = { version = "0.22.6", features = ["extension-module"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } -restate-sdk-shared-core = { version = "0.2.0", features = ["request_identity", "sha2_random_seed"] } +restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", branch = "main", features = ["request_identity", "sha2_random_seed"] } diff --git a/README.md b/README.md index 139d17a..4152d00 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,13 @@ Setup your virtual environment using the tool of your choice, e.g. VirtualEnv: ```shell python3 -m venv .venv -source venv/bin/activate +source .venv/bin/activate ``` -Install `maturin`: +Install the build tools: ```shell -pip install maturin +pip install -r requirements.txt ``` Now build the Rust module and include opt-in additional dev dependencies: diff --git a/python/restate/discovery.py b/python/restate/discovery.py index 042d988..1935586 100644 --- a/python/restate/discovery.py +++ b/python/restate/discovery.py @@ -192,6 +192,6 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[ else: protocol_mode = PROTOCOL_MODES[discovered_as] return Endpoint(protocolMode=protocol_mode, - minProtocolVersion=2, - maxProtocolVersion=2, + minProtocolVersion=4, + maxProtocolVersion=4, services=services) diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 02570a4..b181661 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -13,7 +13,7 @@ from datetime import timedelta import inspect -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, cast +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar import typing import traceback @@ -22,8 +22,8 @@ from restate.handler import Handler, handler_from_callable, invoke_handler from restate.serde import BytesSerde, JsonSerde, Serde from restate.server_types import Receive, Send -from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig - +from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig, CANCEL_HANDLE # pylint: disable=line-too-long +from restate._internal import PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived # pylint: disable=import-error,no-name-in-module,line-too-long T = TypeVar('T') I = TypeVar('I') @@ -50,7 +50,7 @@ def __init__(self, server_context, name, serde) -> None: def value(self) -> Awaitable[Any]: vm: VMWrapper = self.server_context.vm handle = vm.sys_get_promise(self.name) - coro = self.server_context.create_poll_coroutine(handle) + coro = self.server_context.create_poll_or_cancel_coroutine(handle) serde = self.serde assert serde is not None @@ -65,18 +65,18 @@ def resolve(self, value: Any) -> Awaitable[None]: assert self.serde is not None value_buffer = self.serde.serialize(value) handle = vm.sys_complete_promise_success(self.name, value_buffer) - return self.server_context.create_poll_coroutine(handle) + return self.server_context.create_poll_or_cancel_coroutine(handle) def reject(self, message: str, code: int = 500) -> Awaitable[None]: vm: VMWrapper = self.server_context.vm py_failure = Failure(code=code, message=message) handle = vm.sys_complete_promise_failure(self.name, py_failure) - return self.server_context.create_poll_coroutine(handle) + return self.server_context.create_poll_or_cancel_coroutine(handle) def peek(self) -> Awaitable[Any | None]: vm: VMWrapper = self.server_context.vm handle = vm.sys_peek_promise(self.name) - coro = self.server_context.create_poll_coroutine(handle) + coro = self.server_context.create_poll_or_cancel_coroutine(handle) serde = self.serde assert serde is not None @@ -109,6 +109,7 @@ def __init__(self, self.attempt_headers = attempt_headers self.send = send self.receive = receive + self.run_coros_to_execute: dict[int, Awaitable[typing.Union[bytes | Failure]]] = {} async def enter(self): """Invoke the user code.""" @@ -125,8 +126,8 @@ async def enter(self): except SuspendedException: pass except Exception as e: - fmt = '\n'.join(traceback.format_exception(e)) - self.vm.notify_error(fmt) + stacktrace = '\n'.join(traceback.format_exception(e)) + self.vm.notify_error(repr(e), stacktrace) raise e async def leave(self): @@ -166,9 +167,8 @@ async def leave(self): 'more_body': False, }) - - async def create_poll_coroutine(self, handle) -> bytes | None: - """Create a coroutine to poll the handle.""" + async def take_and_send_output(self): + """Take output from state machine and send it""" output = self.vm.take_output() if output: await self.send({ @@ -176,10 +176,36 @@ async def create_poll_coroutine(self, handle) -> bytes | None: 'body': bytes(output), 'more_body': True, }) - self.vm.notify_await_point(handle) + + def must_take_notification(self, handle): + """Take notification, which must be present""" + res = self.vm.take_notification(handle) + if isinstance(res, NotReady): + raise ValueError(f"Unexpected value error: {handle}") + if res is None: + return None + if isinstance(res, Failure): + raise TerminalError(res.message, res.code) + return res + + async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None: + """Create a coroutine to poll the handle.""" + await self.take_and_send_output() while True: - res = self.vm.take_async_result(handle) - if isinstance(res, NotReady): + if self.vm.is_completed(handle): + # Handle is completed + return self.must_take_notification(handle) + + # Nothing ready yet, let's try to make some progress + do_progress_response = self.vm.do_progress([handle, CANCEL_HANDLE]) + if isinstance(do_progress_response, PyDoProgressAnyCompleted): + # One of the handles completed, we can continue + continue + if isinstance(do_progress_response, PyDoProgressCancelSignalReceived): + # Raise cancel signal + raise TerminalError("cancelled", 409) + if isinstance(do_progress_response, PyDoProgressReadFromInput): + # We need to read from input chunk = await self.receive() if chunk.get('body', None) is not None: assert isinstance(chunk['body'], bytes) @@ -187,14 +213,13 @@ async def create_poll_coroutine(self, handle) -> bytes | None: if not chunk.get('more_body', False): self.vm.notify_input_closed() continue - if res is None: - return None - if isinstance(res, Failure): - raise TerminalError(res.message, res.code) - return res + if isinstance(do_progress_response, PyDoProgressExecuteRun): + await self.run_coros_to_execute[do_progress_response.handle] + await self.take_and_send_output() + def get(self, name: str, serde: Serde[T] = JsonSerde()) -> typing.Awaitable[Optional[Any]]: - coro = self.create_poll_coroutine(self.vm.sys_get_state(name)) + coro = self.create_poll_or_cancel_coroutine(self.vm.sys_get_state(name)) async def await_point(): """Wait for this handle to be resolved.""" @@ -206,7 +231,7 @@ async def await_point(): return await_point() # do not await here, the caller will do it. def state_keys(self) -> Awaitable[List[str]]: - return self.create_poll_coroutine(self.vm.sys_get_state_keys()) # type: ignore + return self.create_poll_or_cancel_coroutine(self.vm.sys_get_state_keys()) # type: ignore def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None: """Set the value associated with the given name.""" @@ -227,37 +252,23 @@ def request(self) -> Request: body=self.invocation.input_buffer, ) - # pylint: disable=W0236 - # pylint: disable=R0914 - async def run(self, - name: str, - action: Callable[[], T] | Callable[[], Awaitable[T]], - serde: Optional[Serde[T]] = JsonSerde(), - max_attempts: Optional[int] = None, - max_retry_duration: Optional[timedelta] = None) -> T: - assert serde is not None - res = self.vm.sys_run_enter(name) - if isinstance(res, Failure): - raise TerminalError(res.message, res.code) - if isinstance(res, bytes): - return cast(T, serde.deserialize(res)) - # the side effect was not executed before, so we need to execute it now - assert res is None + async def create_run_coroutine(self, + handle: int, + action: Callable[[], T] | Callable[[], Awaitable[T]], + serde: Serde[T], + max_attempts: Optional[int] = None, + max_retry_duration: Optional[timedelta] = None): + """Create a coroutine to poll the handle.""" try: if inspect.iscoroutinefunction(action): action_result = await action() # type: ignore else: action_result = action() buffer = serde.serialize(action_result) - handle = self.vm.sys_run_exit_success(buffer) - await self.create_poll_coroutine(handle) - return action_result + self.vm.propose_run_completion_success(handle, buffer) except TerminalError as t: failure = Failure(code=t.status_code, message=t.message) - handle = self.vm.sys_run_exit_failure(failure) - await self.create_poll_coroutine(handle) - # unreachable - assert False + self.vm.propose_run_completion_failure(handle, failure) # pylint: disable=W0718 except Exception as e: if max_attempts is None and max_retry_duration is None: @@ -266,17 +277,38 @@ async def run(self, failure = Failure(code=500, message=str(e)) max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000) config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms) - exit_handle = self.vm.sys_run_exit_transient(failure=failure, attempt_duration_ms=1, config=config) - if exit_handle is None: - raise e from None # avoid the traceback that says exception was raised while handling another exception - await self.create_poll_coroutine(exit_handle) - # unreachable - assert False + self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config) + + # pylint: disable=W0236 + # pylint: disable=R0914 + def run(self, + name: str, + action: Callable[[], T] | Callable[[], Awaitable[T]], + serde: Optional[Serde[T]] = JsonSerde(), + max_attempts: Optional[int] = None, + max_retry_duration: Optional[timedelta] = None) -> Awaitable[T]: + assert serde is not None + handle = self.vm.sys_run(name) + + # Register closure to run + self.run_coros_to_execute[handle] = self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration) + + # Prepare response coroutine + coro = self.create_poll_or_cancel_coroutine(handle) + async def await_point(): + """Wait for this handle to be resolved.""" + res = await coro + if res is None: + return None + return serde.deserialize(res) + + return await_point() # do not await here, the caller will do it. + def sleep(self, delta: timedelta) -> Awaitable[None]: # convert timedelta to milliseconds millis = int(delta.total_seconds() * 1000) - return self.create_poll_coroutine(self.vm.sys_sleep(millis)) # type: ignore + return self.create_poll_or_cancel_coroutine(self.vm.sys_sleep(millis)) # type: ignore def do_call(self, tpe: Callable[[Any, I], Awaitable[O]], @@ -319,10 +351,10 @@ def do_raw_call(self, async def await_point(s: ServerInvocationContext, h, o: Serde[O]): """Wait for this handle to be resolved, and deserialize the response.""" - res = await s.create_poll_coroutine(h) + res = await s.create_poll_or_cancel_coroutine(h) return o.deserialize(res) # type: ignore - return await_point(self, handle, output_serde) + return await_point(self, handle.result_handle, output_serde) def service_call(self, tpe: Callable[[Any, I], Awaitable[O]], @@ -368,7 +400,7 @@ def awakeable(self, serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]: assert serde is not None name, handle = self.vm.sys_awakeable() - coro = self.create_poll_coroutine(handle) + coro = self.create_poll_or_cancel_coroutine(handle) async def await_point(): """Wait for this handle to be resolved.""" diff --git a/python/restate/vm.py b/python/restate/vm.py index 54b2551..23feef7 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -15,7 +15,7 @@ from dataclasses import dataclass import typing -from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig # pylint: disable=import-error,no-name-in-module,line-too-long +from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long @dataclass class Invocation: @@ -60,8 +60,10 @@ def __init__(self, *args: object) -> None: NOT_READY = NotReady() SUSPENDED = SuspendedException() +CANCEL_HANDLE = CANCEL_NOTIFICATION_HANDLE -AsyncResultType = typing.Optional[typing.Union[bytes, Failure, NotReady]] +NotificationType = typing.Optional[typing.Union[bytes, Failure, NotReady, list[str], str]] +DoProgressResult = typing.Union[PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived] # pylint: disable=line-too-long # pylint: disable=too-many-public-methods class VMWrapper: @@ -91,25 +93,33 @@ def notify_input_closed(self): """Notify the virtual machine that the input has been closed.""" self.vm.notify_input_closed() - def notify_error(self, error: str): + def notify_error(self, error: str, stacktrace: str): """Notify the virtual machine of an error.""" - self.vm.notify_error(error) + self.vm.notify_error(error, stacktrace) def take_output(self) -> typing.Optional[bytes]: """Take the output from the virtual machine.""" return self.vm.take_output() - def notify_await_point(self, handle: int): - """Notify the virtual machine of an await point.""" - self.vm.notify_await_point(handle) - def is_ready_to_execute(self) -> bool: """Returns true when the VM is ready to operate.""" return self.vm.is_ready_to_execute() - def take_async_result(self, handle: typing.Any) -> AsyncResultType: + def is_completed(self, handle: int) -> bool: + """Returns true when the notification handle is completed and hasn't been taken yet.""" + return self.vm.is_completed(handle) + + def do_progress(self, handles: list[int]) -> DoProgressResult: + """Do progress with notifications.""" + result = self.vm.do_progress(handles) + if isinstance(result, PySuspended): + # the state machine had suspended + raise SUSPENDED + return result + + def take_notification(self, handle: int) -> NotificationType: """Take the result of an asynchronous operation.""" - result = self.vm.take_async_result(handle) + result = self.vm.take_notification(handle) if result is None: return NOT_READY if isinstance(result, PyVoid): @@ -121,6 +131,9 @@ def take_async_result(self, handle: typing.Any) -> AsyncResultType: if isinstance(result, PyStateKeys): # success with state keys return result.keys + if isinstance(result, str): + # success with invocation id + return result if isinstance(result, PyFailure): # a terminal failure code = result.code @@ -244,22 +257,11 @@ def sys_send(self, """send an invocation to a service (no response)""" self.vm.sys_send(service, handler, parameter, key, delay) - def sys_run_enter(self, name: str) -> typing.Union[bytes, None, Failure]: + def sys_run(self, name: str) -> int: """ - Enter a side effect - - Returns: - None if the side effect was not journald. - PyFailure if the side effect failed. - bytes if the side effect was successful. + Register a run """ - result = self.vm.sys_run_enter(name) - if result is None: - return None - if isinstance(result, PyFailure): - return Failure(result.code, result.message) # pylint: disable=protected-access - assert isinstance(result, bytes) - return result + return self.vm.sys_run(name) def sys_awakeable(self) -> typing.Tuple[str, int]: """ @@ -280,7 +282,7 @@ def sys_reject_awakeable(self, name: str, failure: Failure): py_failure = PyFailure(failure.code, failure.message) self.vm.sys_complete_awakeable_failure(name, py_failure) - def sys_run_exit_success(self, output: bytes) -> int: + def propose_run_completion_success(self, handle: int, output: bytes) -> int: """ Exit a side effect @@ -290,7 +292,7 @@ def sys_run_exit_success(self, output: bytes) -> int: Returns: handle """ - return self.vm.sys_run_exit_success(output) + return self.vm.propose_run_completion_success(handle, output) def sys_get_promise(self, name: str) -> int: """Returns the promise handle""" @@ -309,7 +311,7 @@ def sys_complete_promise_failure(self, name: str, failure: Failure) -> int: res = PyFailure(failure.code, failure.message) return self.vm.sys_complete_promise_failure(name, res) - def sys_run_exit_failure(self, output: Failure) -> int: + def propose_run_completion_failure(self, handle: int, output: Failure) -> int: """ Exit a side effect @@ -318,10 +320,10 @@ def sys_run_exit_failure(self, output: Failure) -> int: output: The output of the side effect. """ res = PyFailure(output.code, output.message) - return self.vm.sys_run_exit_failure(res) + return self.vm.propose_run_completion_failure(handle, res) # pylint: disable=line-too-long - def sys_run_exit_transient(self, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None: + def propose_run_completion_transient(self, handle: int, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None: """ Exit a side effect with a transient Error. This requires a retry policy to be provided. @@ -329,7 +331,7 @@ def sys_run_exit_transient(self, failure: Failure, attempt_duration_ms: int, con py_failure = PyFailure(failure.code, failure.message) py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration) try: - handle = self.vm.sys_run_exit_failure_transient(py_failure, attempt_duration_ms, py_config) + handle = self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config) # The VM decided not to retry, therefore we get back an handle that will be resolved # with a terminal failure. return handle diff --git a/src/lib.rs b/src/lib.rs index 420bf16..d38d1b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,11 @@ use pyo3::create_exception; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyNone}; +use pyo3::types::{PyBytes, PyNone, PyString}; use restate_sdk_shared_core::{ - AsyncResultHandle, CoreVM, Header, IdentityVerifier, Input, NonEmptyValue, ResponseHead, - RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, - TerminalFailure, VMOptions, Value, VM, + CallHandle, CoreVM, DoProgressResponse, Error, Header, IdentityVerifier, Input, NonEmptyValue, + NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, SuspendedOrVMError, + TakeOutputResult, Target, TerminalFailure, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, }; -use std::borrow::Cow; use std::time::{Duration, SystemTime}; // Current crate version @@ -63,7 +62,7 @@ fn take_output_result_into_py( } } -type PyAsyncResultHandle = u32; +type PyNotificationHandle = u32; #[pyclass] struct PyVoid; @@ -88,12 +87,6 @@ impl PyFailure { } } -impl Into for PyFailure { - fn into(self) -> restate_sdk_shared_core::Error { - restate_sdk_shared_core::Error::new(self.code, self.message) - } -} - #[pyclass] #[derive(Clone)] struct PyExponentialRetryConfig { @@ -152,6 +145,12 @@ impl From for TerminalFailure { } } +impl From for Error { + fn from(value: PyFailure) -> Self { + Self::new(value.code, value.message) + } +} + #[pyclass] #[derive(Clone)] struct PyStateKeys { @@ -185,10 +184,42 @@ impl From for PyInput { } } +#[pyclass] +struct PyDoProgressReadFromInput; + +#[pyclass] +struct PyDoProgressAnyCompleted; + +#[pyclass] +struct PyDoProgressExecuteRun { + #[pyo3(get)] + handle: PyNotificationHandle, +} + +#[pyclass] +struct PyDoProgressCancelSignalReceived; + +#[pyclass] +pub struct PyCallHandle { + #[pyo3(get)] + invocation_id_handle: PyNotificationHandle, + #[pyo3(get)] + result_handle: PyNotificationHandle, +} + +impl From for PyCallHandle { + fn from(value: CallHandle) -> Self { + PyCallHandle { + invocation_id_handle: value.invocation_id_notification_handle.into(), + result_handle: value.call_notification_handle.into(), + } + } +} + // Errors and Exceptions #[derive(Debug)] -struct PyVMError(restate_sdk_shared_core::Error); +struct PyVMError(Error); // Python representation of restate_sdk_shared_core::Error create_exception!( @@ -204,8 +235,8 @@ impl From for PyErr { } } -impl From for PyVMError { - fn from(value: restate_sdk_shared_core::Error) -> Self { +impl From for PyVMError { + fn from(value: Error) -> Self { PyVMError(value) } } @@ -241,13 +272,13 @@ impl PyVM { self_.vm.notify_input_closed(); } - #[pyo3(signature = (error, description=None))] - fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, description: Option) { - let mut e = restate_sdk_shared_core::Error::new(500u16, Cow::Owned(error)); - if let Some(desc) = description { - e = e.with_description(desc); + #[pyo3(signature = (error, stacktrace=None))] + fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option) { + let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error); + if let Some(desc) = stacktrace { + error = error.with_stacktrace(desc); } - CoreVM::notify_error(&mut self_.vm, e, None); + CoreVM::notify_error(&mut self_.vm, error, None); } // Take(s) @@ -261,8 +292,48 @@ impl PyVM { self_.vm.is_ready_to_execute().map_err(Into::into) } - fn notify_await_point(mut self_: PyRefMut<'_, Self>, handle: PyAsyncResultHandle) { - self_.vm.notify_await_point(handle.into()) + fn is_completed(self_: PyRef<'_, Self>, handle: PyNotificationHandle) -> bool { + self_.vm.is_completed(handle.into()) + } + + fn do_progress( + mut self_: PyRefMut<'_, Self>, + any_handle: Vec, + ) -> Result, PyVMError> { + let res = self_.vm.do_progress( + any_handle + .into_iter() + .map(NotificationHandle::from) + .collect(), + ); + + let py = self_.py(); + + match res { + Err(SuspendedOrVMError::VM(e)) => Err(e.into()), + Err(SuspendedOrVMError::Suspended(_)) => { + Ok(PySuspended.into_py(py).into_bound(py).into_any()) + } + Ok(DoProgressResponse::AnyCompleted) => Ok(PyDoProgressAnyCompleted + .into_py(py) + .into_bound(py) + .into_any()), + Ok(DoProgressResponse::ReadFromInput) => Ok(PyDoProgressReadFromInput + .into_py(py) + .into_bound(py) + .into_any()), + Ok(DoProgressResponse::ExecuteRun(handle)) => Ok(PyDoProgressExecuteRun { + handle: handle.into(), + } + .into_py(py) + .into_bound(py) + .into_any()), + Ok(DoProgressResponse::CancelSignalReceived) => Ok(PyDoProgressCancelSignalReceived + .into_py(py) + .into_bound(py) + .into_any()), + Ok(DoProgressResponse::WaitingPendingRun) => panic!("Python SDK doesn't support concurrent pending runs, so this is not supposed to happen") + } } /// Returns either: @@ -270,13 +341,15 @@ impl PyVM { /// * `PyBytes` in case the async result holds success value /// * `PyFailure` in case the async result holds failure value /// * `PyVoid` in case the async result holds Void value + /// * `PyStateKeys` in case the async result holds StateKeys + /// * `PyString` in case the async result holds invocation id /// * `PySuspended` in case the state machine is suspended /// * `None` in case the async result is not yet present - fn take_async_result( + fn take_notification( mut self_: PyRefMut<'_, Self>, - handle: PyAsyncResultHandle, + handle: PyNotificationHandle, ) -> Result, PyVMError> { - let res = self_.vm.take_async_result(AsyncResultHandle::from(handle)); + let res = self_.vm.take_notification(NotificationHandle::from(handle)); let py = self_.py(); @@ -294,8 +367,8 @@ impl PyVM { Ok(Some(Value::StateKeys(keys))) => { Ok(PyStateKeys { keys }.into_py(py).into_bound(py).into_any()) } - Ok(Some(Value::InvocationId(_))) | Ok(Some(Value::CombinatorResult(_))) => { - panic!("Unsupported variants, the python SDK doesn't support these features yet!") + Ok(Some(Value::InvocationId(invocation_id))) => { + Ok(PyString::new_bound(py, &invocation_id).into_any()) } } } @@ -309,7 +382,7 @@ impl PyVM { fn sys_get_state( mut self_: PyRefMut<'_, Self>, key: String, - ) -> Result { + ) -> Result { self_ .vm .sys_state_get(key) @@ -317,7 +390,9 @@ impl PyVM { .map_err(Into::into) } - fn sys_get_state_keys(mut self_: PyRefMut<'_, Self>) -> Result { + fn sys_get_state_keys( + mut self_: PyRefMut<'_, Self>, + ) -> Result { self_ .vm .sys_state_get_keys() @@ -347,7 +422,7 @@ impl PyVM { fn sys_sleep( mut self_: PyRefMut<'_, Self>, millis: u64, - ) -> Result { + ) -> Result { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Duration since unix epoch cannot fail"); @@ -365,7 +440,7 @@ impl PyVM { handler: String, buffer: &Bound<'_, PyBytes>, key: Option, - ) -> Result { + ) -> Result { self_ .vm .sys_call( @@ -390,7 +465,7 @@ impl PyVM { buffer: &Bound<'_, PyBytes>, key: Option, delay: Option, - ) -> Result<(), PyVMError> { + ) -> Result { self_ .vm .sys_send( @@ -409,13 +484,13 @@ impl PyVM { + Duration::from_millis(millis) }), ) - .map(|_| ()) + .map(|s| s.invocation_id_notification_handle.into()) .map_err(Into::into) } fn sys_awakeable( mut self_: PyRefMut<'_, Self>, - ) -> Result<(String, PyAsyncResultHandle), PyVMError> { + ) -> Result<(String, PyNotificationHandle), PyVMError> { self_ .vm .sys_awakeable() @@ -451,7 +526,7 @@ impl PyVM { fn sys_get_promise( mut self_: PyRefMut<'_, Self>, key: String, - ) -> Result { + ) -> Result { self_ .vm .sys_get_promise(key) @@ -462,7 +537,7 @@ impl PyVM { fn sys_peek_promise( mut self_: PyRefMut<'_, Self>, key: String, - ) -> Result { + ) -> Result { self_ .vm .sys_peek_promise(key) @@ -474,7 +549,7 @@ impl PyVM { mut self_: PyRefMut<'_, Self>, key: String, buffer: &Bound<'_, PyBytes>, - ) -> Result { + ) -> Result { self_ .vm .sys_complete_promise( @@ -489,7 +564,7 @@ impl PyVM { mut self_: PyRefMut<'_, Self>, key: String, value: PyFailure, - ) -> Result { + ) -> Result { self_ .vm .sys_complete_promise(key, NonEmptyValue::Failure(value.into())) @@ -497,73 +572,60 @@ impl PyVM { .map_err(Into::into) } - /// Returns either: - /// - /// * `PyBytes`, in case the run was executed with success - /// * `PyFailure`, in case the run was executed with failure - /// * `None` in case the run was not executed - fn sys_run_enter( + /// Returns the associated `PyNotificationHandle`. + fn sys_run( mut self_: PyRefMut<'_, Self>, name: String, - ) -> Result, PyVMError> { - let result = self_.vm.sys_run_enter(name)?; - - let py = self_.py(); - - Ok(match result { - RunEnterResult::Executed(NonEmptyValue::Success(b)) => { - PyBytes::new_bound(py, &b).into_any() - } - RunEnterResult::Executed(NonEmptyValue::Failure(f)) => { - PyFailure::from(f).into_py(py).into_bound(py).into_any() - } - RunEnterResult::NotExecuted(_retry_info) => PyNone::get_bound(py).to_owned().into_any(), - }) + ) -> Result { + self_.vm.sys_run(name).map(Into::into).map_err(Into::into) } - fn sys_run_exit_success( + fn propose_run_completion_success( mut self_: PyRefMut<'_, Self>, + handle: PyNotificationHandle, buffer: &Bound<'_, PyBytes>, - ) -> Result { - CoreVM::sys_run_exit( + ) -> Result<(), PyVMError> { + CoreVM::propose_run_completion( &mut self_.vm, + handle.into(), RunExitResult::Success(buffer.as_bytes().to_vec().into()), RetryPolicy::None, ) - .map(Into::into) .map_err(Into::into) } - fn sys_run_exit_failure( + fn propose_run_completion_failure( mut self_: PyRefMut<'_, Self>, + handle: PyNotificationHandle, value: PyFailure, - ) -> Result { + ) -> Result<(), PyVMError> { self_ .vm - .sys_run_exit( + .propose_run_completion( + handle.into(), RunExitResult::TerminalFailure(value.into()), RetryPolicy::None, ) - .map(Into::into) .map_err(Into::into) } - fn sys_run_exit_failure_transient( + fn propose_run_completion_failure_transient( mut self_: PyRefMut<'_, Self>, + handle: PyNotificationHandle, value: PyFailure, attempt_duration: u64, config: PyExponentialRetryConfig, - ) -> Result { + ) -> Result<(), PyVMError> { self_ .vm - .sys_run_exit( + .propose_run_completion( + handle.into(), RunExitResult::RetryableFailure { attempt_duration: Duration::from_millis(attempt_duration), error: value.into(), }, config.into(), ) - .map(Into::into) .map_err(Into::into) } @@ -654,6 +716,11 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add("VMException", m.py().get_type_bound::())?; m.add( @@ -665,5 +732,9 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> { m.py().get_type_bound::(), )?; m.add("SDK_VERSION", CURRENT_VERSION)?; + m.add( + "CANCEL_NOTIFICATION_HANDLE", + PyNotificationHandle::from(CANCEL_NOTIFICATION_HANDLE), + )?; Ok(()) }