Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def payment_gateway():
print("To decline use:")
print(f"""curl http://localhost:8080/payment/{workflow_key}/payment_verified --json '"declined"' """)

await ctx.run("payment", payment_gateway)
await ctx.run_typed("payment", payment_gateway)

ctx.set("status", "waiting for the payment provider to approve")

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ license = { file = "LICENSE" }
authors = [
{ name = "Restate Developers", email = "dev@restate.dev" }
]
dependencies = [
"typing-extensions>=4.14.0"
]


[project.optional-dependencies]
test = ["pytest", "hypercorn"]
Expand Down
3 changes: 2 additions & 1 deletion python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .context import Context, ObjectContext, ObjectSharedContext
from .context import WorkflowContext, WorkflowSharedContext
# pylint: disable=line-too-long
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle, RunOptions
from .exceptions import TerminalError
from .asyncio import as_completed, gather, wait_completed, select

Expand Down Expand Up @@ -50,6 +50,7 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
"RestateDurableCallFuture",
"RestateDurableSleepFuture",
"SendHandle",
"RunOptions",
"TerminalError",
"app",
"test_harness",
Expand Down
97 changes: 95 additions & 2 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,39 @@

import abc
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine, overload
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine, overload, ParamSpec
import typing
from datetime import timedelta

import typing_extensions
from restate.serde import DefaultSerde, Serde

T = TypeVar('T')
I = TypeVar('I')
O = TypeVar('O')
P = ParamSpec('P')

RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]]
HandlerType = Union[Callable[[Any, I], Awaitable[O]], Callable[[Any], Awaitable[O]]]
RunAction = Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]]

@dataclass
class RunOptions(typing.Generic[T]):
"""
Options for running an action.
"""

serde: Serde[T] = DefaultSerde()
"""The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
See also 'type_hint'."""
max_attempts: Optional[int] = None
"""The maximum number of retry attempts to complete the action.
If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError."""
max_retry_duration: Optional[timedelta] = None
"""The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError."""
type_hint: Optional[typing.Type[T]] = None
"""The type hint of the return value of the action. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer."""

# pylint: disable=R0903
class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
Expand Down Expand Up @@ -197,6 +219,8 @@ def request(self) -> Request:
Returns the request object.
"""


@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
@overload
@abc.abstractmethod
def run(self,
Expand Down Expand Up @@ -226,6 +250,7 @@ def run(self,

"""

@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
@overload
@abc.abstractmethod
def run(self,
Expand Down Expand Up @@ -255,6 +280,7 @@ def run(self,

"""

@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
@abc.abstractmethod
def run(self,
name: str,
Expand Down Expand Up @@ -283,6 +309,73 @@ def run(self,

"""


@overload
@abc.abstractmethod
def run_typed(self,
name: str,
action: Callable[P, Coroutine[Any, Any,T]],
options: RunOptions[T] = RunOptions(),
/,
*args: P.args,
**kwargs: P.kwargs,
) -> RestateDurableFuture[T]:
"""
Typed version of run that provides type hints for the function arguments.
Runs the given action with the given name.

Args:
name: The name of the action.
action: The action to run.
options: The options for the run.
*args: The arguments to pass to the action.
**kwargs: The keyword arguments to pass to the action.
"""

@overload
@abc.abstractmethod
def run_typed(self,
name: str,
action: Callable[P, T],
options: RunOptions[T] = RunOptions(),
/,
*args: P.args,
**kwargs: P.kwargs,
) -> RestateDurableFuture[T]:
"""
Typed version of run that provides type hints for the function arguments.
Runs the given coroutine action with the given name.

Args:
name: The name of the action.
action: The action to run.
options: The options for the run.
*args: The arguments to pass to the action.
**kwargs: The keyword arguments to pass to the action.
"""

@abc.abstractmethod
def run_typed(self,
name: str,
action: Union[Callable[P, Coroutine[Any, Any, T]], Callable[P, T]],
options: RunOptions[T] = RunOptions(),
/,
*args: P.args,
**kwargs: P.kwargs,
) -> RestateDurableFuture[T]:
"""
Typed version of run that provides type hints for the function arguments.
Runs the given action with the given name.

Args:
name: The name of the action.
action: The action to run.
options: The options for the run.
*args: The arguments to pass to the action.
**kwargs: The keyword arguments to pass to the action.

"""

@abc.abstractmethod
def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
"""
Expand Down
25 changes: 23 additions & 2 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@
from datetime import timedelta
import inspect
import functools
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union, Coroutine
import typing
import traceback

from restate.context import DurablePromise, AttemptFinishedEvent, HandlerType, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, RunAction, SendHandle, RestateDurableSleepFuture
from restate.context import DurablePromise, AttemptFinishedEvent, HandlerType, ObjectContext, Request, RestateDurableCallFuture, RestateDurableFuture, RunAction, SendHandle, RestateDurableSleepFuture, RunOptions, P
from restate.exceptions import TerminalError
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
from restate.server_types import ReceiveChannel, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun
import typing_extensions


T = TypeVar('T')
Expand Down Expand Up @@ -510,6 +511,7 @@ async def create_run_coroutine(self,
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config)
# pylint: disable=W0236
# pylint: disable=R0914
@typing_extensions.deprecated("`run` is deprecated, use `run_typed` instead for better type safety")
def run(self,
name: str,
action: RunAction[T],
Expand All @@ -536,6 +538,25 @@ def run(self,
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, noargs_action, serde, max_attempts, max_retry_duration)
return self.create_future(handle, serde) # type: ignore

def run_typed(
self,
name: str,
action: Union[Callable[P, T], Callable[P, Coroutine[Any, Any, T]]],
options: RunOptions[T] = RunOptions(),
/,
*args: P.args,
**kwargs: P.kwargs,
) -> RestateDurableFuture[T]:
if isinstance(options.serde, DefaultSerde):
if options.type_hint is None:
signature = inspect.signature(action, eval_str=True)
options.type_hint = signature.return_annotation
options.serde = options.serde.with_maybe_type(options.type_hint)
handle = self.vm.sys_run(name)

func = functools.partial(action, *args, **kwargs)
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, func, options.serde, options.max_attempts, options.max_retry_duration)
return self.create_future(handle, options.serde)

def sleep(self, delta: timedelta) -> RestateDurableSleepFuture:
# convert timedelta to milliseconds
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pytest
pydantic
httpx
testcontainers
typing-extensions>=4.14.0
10 changes: 6 additions & 4 deletions test-services/services/failing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from restate import VirtualObject, ObjectContext
from restate.exceptions import TerminalError
from restate import RunOptions

failing = VirtualObject("Failing")

Expand Down Expand Up @@ -45,23 +46,23 @@ async def terminally_failing_side_effect(ctx: ObjectContext, error_message: str)
def side_effect():
raise TerminalError(message=error_message)

await ctx.run("sideEffect", side_effect)
await ctx.run_typed("sideEffect", side_effect)
raise ValueError("Should not reach here")


eventual_success_side_effects = 0

@failing.handler(name="sideEffectSucceedsAfterGivenAttempts")
async def side_effect_succeeds_after_given_attempts(ctx: ObjectContext, minimum_attempts: int) -> int:

def side_effect():
global eventual_success_side_effects
eventual_success_side_effects += 1
if eventual_success_side_effects >= minimum_attempts:
return eventual_success_side_effects
raise ValueError(f"Failed at attempt: {eventual_success_side_effects}")

return await ctx.run("sideEffect", side_effect, max_attempts=minimum_attempts + 1) # type: ignore
options: RunOptions[int] = RunOptions(max_attempts=minimum_attempts + 1)
return await ctx.run_typed("sideEffect", side_effect, options)

eventual_failure_side_effects = 0

Expand All @@ -74,7 +75,8 @@ def side_effect():
raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}")

try:
await ctx.run("sideEffect", side_effect, max_attempts=retry_policy_max_retry_count)
options: RunOptions[int] = RunOptions(max_attempts=retry_policy_max_retry_count)
await ctx.run_typed("sideEffect", side_effect, options)
raise ValueError("Side effect did not fail.")
except TerminalError as t:
global eventual_failure_side_effects
Expand Down
4 changes: 2 additions & 2 deletions test-services/services/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def await_promise(index: int) -> None:
coros[i] = (expected, service.echo_later(expected, command['sleep']))
elif command_type == SIDE_EFFECT:
expected = f"hello-{i}"
result = await ctx.run("sideEffect", lambda : expected) # pylint: disable=W0640
result = await ctx.run_typed("sideEffect", lambda: expected)
if result != expected:
raise TerminalError(f"Expected {expected} but got {result}")
elif command_type == SLOW_SIDE_EFFECT:
Expand All @@ -246,7 +246,7 @@ async def side_effect():
if bool(random.getrandbits(1)):
raise ValueError("Random error")

await ctx.run("throwingSideEffect", side_effect)
await ctx.run_typed("throwingSideEffect", side_effect)
elif command_type == INCREMENT_STATE_COUNTER_INDIRECTLY:
await service.increment_indirectly(layer=layer, key=ctx.key())
elif command_type == AWAIT_PROMISE:
Expand Down
8 changes: 5 additions & 3 deletions test-services/services/virtual_object_command_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def to_durable_future(ctx: ObjectContext, cmd: AwaitableCommand) -> RestateDurab
elif cmd['type'] == "sleep":
return ctx.sleep(timedelta(milliseconds=cmd['timeoutMillis']))
elif cmd['type'] == "runThrowTerminalException":
def side_effect(reason):
def side_effect(reason: str):
raise TerminalError(message=reason)
res = ctx.run("run should fail command", side_effect, args=(cmd['reason'],))
res = ctx.run_typed("run should fail command", side_effect, reason=cmd['reason'])
return res

@virtual_object_command_interpreter.handler(name="interpretCommands")
Expand All @@ -142,7 +142,9 @@ async def interpret_commands(ctx: ObjectContext, req: InterpretRequest):
result = ""
elif cmd['type'] == "getEnvVariable":
env_name = cmd['envName']
result = await ctx.run("get_env", lambda e=env_name: os.environ.get(e, ""))
def side_effect(env_name: str):
return os.environ.get(env_name, "")
result = await ctx.run_typed("get_env", side_effect, env_name=env_name)
elif cmd['type'] == "awaitOne":
awaitable = to_durable_future(ctx, cmd['command'])
# We need this dance because the Python SDK doesn't support .map on futures
Expand Down