Skip to content

Add side effect retry #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ doc = false
[dependencies]
pyo3 = { version = "0.22.0", features = ["extension-module"] }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
restate-sdk-shared-core = "0.0.5"
restate-sdk-shared-core = "0.1.0"
bytes = "1.6.0"
17 changes: 15 additions & 2 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
# pylint: disable=R0913,C0301
"""
Restate Context
"""
Expand All @@ -25,6 +26,7 @@

RunAction = Union[Callable[[], T], Callable[[], Awaitable[T]]]


@dataclass
class Request:
"""
Expand Down Expand Up @@ -79,7 +81,6 @@ def clear(self, name: str) -> None:
def clear_all(self) -> None:
"""clear all the values in the store."""


class Context(abc.ABC):
"""
Represents the context of the current invocation.
Expand All @@ -95,9 +96,21 @@ def request(self) -> Request:
def run(self,
name: str,
action: RunAction[T],
serde: Serde[T] = JsonSerde()) -> Awaitable[T]:
serde: Serde[T] = JsonSerde(),
max_attempts: typing.Optional[int] = None,
max_retry_duration: typing.Optional[timedelta] = None) -> Awaitable[T]:
"""
Runs the given action with the given name.

Args:
name: The name of the action.
action: The action to run.
serde: The serialization/deserialization mechanism.
max_attempts: The maximum number of retry attempts to complete the action.
If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError.
max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError.
"""

@abc.abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,6 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
else:
protocol_mode = PROTOCOL_MODES[discovered_as]
return Endpoint(protocolMode=protocol_mode,
minProtocolVersion=1,
maxProtocolVersion=1,
minProtocolVersion=2,
maxProtocolVersion=2,
services=services)
21 changes: 19 additions & 2 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, JsonSerde, Serde
from restate.server_types import Receive, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig


T = TypeVar('T')
Expand Down Expand Up @@ -227,10 +227,13 @@ def request(self) -> Request:
)

# pylint: disable=W0236
# pylint: disable=R0914
async def run(self,
name: str,
action: Callable[[], T] | Callable[[], Awaitable[T]],
serde: Optional[Serde[T]] = JsonSerde()) -> T:
serde: Optional[Serde[T]] = JsonSerde(),
max_attempts: Optional[int] = None,
max_retry_duration: Optional[timedelta] = None) -> T:
assert serde is not None
res = self.vm.sys_run_enter(name)
if isinstance(res, Failure):
Expand All @@ -254,6 +257,20 @@ async def run(self,
await self.create_poll_coroutine(handle)
# unreachable
assert False
# pylint: disable=W0718
except Exception as e:
if max_attempts is None and max_retry_duration is None:
# no retry policy
raise e
failure = Failure(code=500, message=str(e))
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000)
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms)
exit_handle = self.vm.sys_run_exit_transient(failure=failure, attempt_duration_ms=1, config=config)
if exit_handle is None:
raise e from None # avoid the traceback that says exception was raised while handling another exception
await self.create_poll_coroutine(exit_handle)
# unreachable
assert False

def sleep(self, delta: timedelta) -> Awaitable[None]:
# convert timedelta to milliseconds
Expand Down
28 changes: 27 additions & 1 deletion python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from dataclasses import dataclass
import typing
from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys # pylint: disable=import-error,no-name-in-module
from restate._internal import PyVM, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig # pylint: disable=import-error,no-name-in-module,line-too-long

@dataclass
class Invocation:
Expand All @@ -28,6 +28,14 @@ class Invocation:
input_buffer: bytes
key: str

@dataclass
class RunRetryConfig:
"""
Expo Retry Configuration
"""
initial_interval: typing.Optional[int] = None
max_attempts: typing.Optional[int] = None
max_duration: typing.Optional[int] = None

@dataclass
class Failure:
Expand Down Expand Up @@ -312,6 +320,24 @@ def sys_run_exit_failure(self, output: Failure) -> int:
res = PyFailure(output.code, output.message)
return self.vm.sys_run_exit_failure(res)

# pylint: disable=line-too-long
def sys_run_exit_transient(self, failure: Failure, attempt_duration_ms: int, config: RunRetryConfig) -> int | None:
"""
Exit a side effect with a transient Error.
This requires a retry policy to be provided.
"""
py_failure = PyFailure(failure.code, failure.message)
py_config = PyExponentialRetryConfig(config.initial_interval, config.max_attempts, config.max_duration)
try:
handle = self.vm.sys_run_exit_failure_transient(py_failure, attempt_duration_ms, py_config)
# The VM decided not to retry, therefore we get back an handle that will be resolved
# with a terminal failure.
return handle
# pylint: disable=bare-except
except:
# The VM decided to retry, therefore we tear down the current execution
return None
Comment on lines +337 to +339
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not rethrowing (as you were doing before essentially)?


def sys_end(self):
"""
This method is responsible for ending the system.
Expand Down
104 changes: 90 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyNone};
use restate_sdk_shared_core::{
AsyncResultHandle, CoreVM, Failure, Header, IdentityVerifier, Input, NonEmptyValue,
ResponseHead, RunEnterResult, SuspendedOrVMError, TakeOutputResult, Target, VMError, Value, VM,
ResponseHead, RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult,
Target, VMError, Value, VM,
};
use std::borrow::Cow;
use std::time::Duration;
Expand Down Expand Up @@ -87,6 +88,46 @@ impl PyFailure {
}
}

#[pyclass]
#[derive(Clone)]
struct PyExponentialRetryConfig {
#[pyo3(get, set)]
initial_interval: Option<u64>,
#[pyo3(get, set)]
max_attempts: Option<u32>,
#[pyo3(get, set)]
max_duration: Option<u64>,
}

#[pymethods]
impl PyExponentialRetryConfig {
#[pyo3(signature = (initial_interval=None, max_attempts=None, max_duration=None))]
#[new]
fn new(
initial_interval: Option<u64>,
max_attempts: Option<u32>,
max_duration: Option<u64>,
) -> Self {
Self {
initial_interval,
max_attempts,
max_duration,
}
}
}

impl From<PyExponentialRetryConfig> for RetryPolicy {
fn from(value: PyExponentialRetryConfig) -> Self {
RetryPolicy::Exponential {
initial_interval: Duration::from_millis(value.initial_interval.unwrap_or(10)),
max_attempts: value.max_attempts,
max_duration: value.max_duration.map(Duration::from_millis),
factor: 2.0,
max_interval: None,
Comment on lines +125 to +126
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not exposing these two as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll expose when we'd need them

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll need them for the e2e if you want the test to complete quickly :) https://github.com/restatedev/sdk-rust/blob/main/test-services/src/failing.rs#L122

Also I think in particular max_interval is important, otherwise the retry delay with exponent 2 quickly evolves into some too long delay!

}
}
}

impl From<Failure> for PyFailure {
fn from(value: Failure) -> Self {
PyFailure {
Expand Down Expand Up @@ -133,7 +174,7 @@ impl From<Input> for PyInput {
random_seed: value.random_seed,
key: value.key,
headers: value.headers.into_iter().map(Into::into).collect(),
input: value.input,
input: value.input.into(),
}
}
}
Expand Down Expand Up @@ -186,7 +227,8 @@ impl PyVM {
// Notifications

fn notify_input(mut self_: PyRefMut<'_, Self>, buffer: &Bound<'_, PyBytes>) {
self_.vm.notify_input(buffer.as_bytes().to_vec());
let buf = buffer.as_bytes().to_vec().into();
self_.vm.notify_input(buf);
}

fn notify_input_closed(mut self_: PyRefMut<'_, Self>) {
Expand All @@ -195,9 +237,11 @@ impl PyVM {

#[pyo3(signature = (error, description=None))]
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, description: Option<String>) {
self_.vm.notify_error(
CoreVM::notify_error(
&mut self_.vm,
Cow::Owned(error),
description.map(Cow::Owned).unwrap_or(Cow::Borrowed("")),
None,
);
}

Expand Down Expand Up @@ -280,7 +324,7 @@ impl PyVM {
) -> Result<(), PyVMError> {
self_
.vm
.sys_state_set(key, buffer.as_bytes().to_vec())
.sys_state_set(key, buffer.as_bytes().to_vec().into())
.map_err(Into::into)
}

Expand Down Expand Up @@ -319,7 +363,7 @@ impl PyVM {
handler,
key,
},
buffer.as_bytes().to_vec(),
buffer.as_bytes().to_vec().into(),
)
.map(Into::into)
.map_err(Into::into)
Expand All @@ -342,7 +386,7 @@ impl PyVM {
handler,
key,
},
buffer.as_bytes().to_vec(),
buffer.as_bytes().to_vec().into(),
delay.map(Duration::from_millis),
)
.map_err(Into::into)
Expand All @@ -365,7 +409,10 @@ impl PyVM {
) -> Result<(), PyVMError> {
self_
.vm
.sys_complete_awakeable(id, NonEmptyValue::Success(buffer.as_bytes().to_vec()))
.sys_complete_awakeable(
id,
NonEmptyValue::Success(buffer.as_bytes().to_vec().into()),
)
.map_err(Into::into)
}

Expand Down Expand Up @@ -409,7 +456,10 @@ impl PyVM {
) -> Result<PyAsyncResultHandle, PyVMError> {
self_
.vm
.sys_complete_promise(key, NonEmptyValue::Success(buffer.as_bytes().to_vec()))
.sys_complete_promise(
key,
NonEmptyValue::Success(buffer.as_bytes().to_vec().into()),
)
.map(Into::into)
.map_err(Into::into)
}
Expand Down Expand Up @@ -446,28 +496,52 @@ impl PyVM {
RunEnterResult::Executed(NonEmptyValue::Failure(f)) => {
PyFailure::from(f).into_py(py).into_bound(py).into_any()
}
RunEnterResult::NotExecuted => PyNone::get_bound(py).to_owned().into_any(),
RunEnterResult::NotExecuted(_retry_info) => PyNone::get_bound(py).to_owned().into_any(),
})
}

fn sys_run_exit_success(
mut self_: PyRefMut<'_, Self>,
buffer: &Bound<'_, PyBytes>,
) -> Result<PyAsyncResultHandle, PyVMError> {
CoreVM::sys_run_exit(
&mut self_.vm,
RunExitResult::Success(buffer.as_bytes().to_vec().into()),
RetryPolicy::None,
)
.map(Into::into)
.map_err(Into::into)
}

fn sys_run_exit_failure(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps to avoid any confusion is better to rename this to sys_run_exit_failure_terminal

mut self_: PyRefMut<'_, Self>,
value: PyFailure,
) -> Result<PyAsyncResultHandle, PyVMError> {
self_
.vm
.sys_run_exit(NonEmptyValue::Success(buffer.as_bytes().to_vec()))
.sys_run_exit(
RunExitResult::TerminalFailure(value.into()),
RetryPolicy::None,
)
.map(Into::into)
.map_err(Into::into)
}

fn sys_run_exit_failure(
fn sys_run_exit_failure_transient(
mut self_: PyRefMut<'_, Self>,
value: PyFailure,
attempt_duration: u64,
config: PyExponentialRetryConfig,
) -> Result<PyAsyncResultHandle, PyVMError> {
self_
.vm
.sys_run_exit(NonEmptyValue::Failure(value.into()))
.sys_run_exit(
RunExitResult::RetryableFailure {
attempt_duration: Duration::from_millis(attempt_duration),
failure: value.into(),
},
config.into(),
)
.map(Into::into)
.map_err(Into::into)
}
Expand All @@ -478,7 +552,7 @@ impl PyVM {
) -> Result<(), PyVMError> {
self_
.vm
.sys_write_output(NonEmptyValue::Success(buffer.as_bytes().to_vec()))
.sys_write_output(NonEmptyValue::Success(buffer.as_bytes().to_vec().into()))
.map(Into::into)
.map_err(Into::into)
}
Expand Down Expand Up @@ -558,6 +632,8 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PySuspended>()?;
m.add_class::<PyVM>()?;
m.add_class::<PyIdentityVerifier>()?;
m.add_class::<PyExponentialRetryConfig>()?;

m.add("VMException", m.py().get_type_bound::<VMException>())?;
m.add(
"IdentityKeyException",
Expand Down