Skip to content

Support ctx.run combinators #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions python/restate/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
Returns a tuple of two lists: the first list contains the futures that are completed,
the second list contains the futures that are not completed.
"""
if not futures:
return [], []
handles: List[int] = []
context: ServerInvocationContext | None = None
completed = []
uncompleted = []

if not futures:
return [], []
for f in futures:
if not isinstance(f, ServerDurableFuture):
raise TerminalError("All futures must SDK created futures.")
Expand All @@ -103,17 +106,24 @@ async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List
elif context is not f.context:
raise TerminalError("All futures must be created by the same SDK context.")
if f.is_completed():
return [f], []
handles.append(f.source_notification_handle)
completed.append(f)
else:
handles.append(f.source_notification_handle)
uncompleted.append(f)

if completed:
# the user had passed some completed futures, so we can return them immediately
return completed, uncompleted # type: ignore

assert context is not None
await context.create_poll_or_cancel_coroutine(handles)
completed = []
uncompleted = []
assert context is not None
await context.create_poll_or_cancel_coroutine(handles)

for index, handle in enumerate(handles):
future = futures[index]
if context.vm.is_completed(handle):
completed.append(future)
completed.append(future) # type: ignore
else:
uncompleted.append(future)
return completed, uncompleted
uncompleted.append(future) # type: ignore
return completed, uncompleted # type: ignore
6 changes: 6 additions & 0 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def is_completed(self) -> bool:
def __await__(self):
pass

@abc.abstractmethod
def map_value(self, mapper: Callable[[T], O]) -> 'RestateDurableFuture[O]':
"""
Maps the value of the future using the given function.
"""


# pylint: disable=R0903
class RestateDurableCallFuture(RestateDurableFuture[T]):
Expand Down
43 changes: 39 additions & 4 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
from restate.server_types import Receive, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun # pylint: disable=line-too-long

T = TypeVar('T')
I = TypeVar('I')
Expand Down Expand Up @@ -59,6 +59,13 @@ def is_completed(self):
case "rejected":
return True

def map_value(self, mapper: Callable[[T], O]) -> RestateDurableFuture[O]:
"""Map the value of the future."""
async def mapper_coro():
return mapper(await self)

return ServerDurableFuture(self.context, self.source_notification_handle, mapper_coro)


def __await__(self):

Expand Down Expand Up @@ -190,6 +197,25 @@ async def await_point():
# disable too many public method
# pylint: disable=R0904

class SyncPoint:
"""
This class implements a synchronization point.
"""

def __init__(self):
self._cond = asyncio.Condition()

async def wait(self):
"""Wait for the sync point."""
async with self._cond:
await self._cond.wait()

async def arrive(self):
"""Arrive at the sync point."""
async with self._cond:
self._cond.notify_all()

# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
"""This class implements the context for the restate framework based on the server."""

Expand All @@ -208,6 +234,7 @@ def __init__(self,
self.send = send
self.receive = receive
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[typing.Union[bytes | Failure]]]] = {}
self.sync_point = SyncPoint()

async def enter(self):
"""Invoke the user code."""
Expand Down Expand Up @@ -306,9 +333,18 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
self.vm.notify_input_closed()
continue
if isinstance(do_progress_response, DoProgressExecuteRun):
await self.run_coros_to_execute[do_progress_response.handle]()
await self.take_and_send_output()
fn = self.run_coros_to_execute[do_progress_response.handle]
assert fn is not None

async def wrapper(f):
await f()
await self.take_and_send_output()
await self.sync_point.arrive()

asyncio.create_task(wrapper(fn))
continue
if isinstance(do_progress_response, DoWaitPendingRun):
await self.sync_point.wait()

def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
"""Create a durable future."""
Expand Down Expand Up @@ -407,7 +443,6 @@ def run(self,
handle = self.vm.sys_run(name)

# Register closure to run
# TODO: use thunk to avoid coro leak warning.
self.run_coros_to_execute[handle] = lambda : self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)

# Prepare response coroutine
Expand Down
13 changes: 11 additions & 2 deletions python/restate/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from dataclasses import dataclass
import typing
from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long
from restate._internal import PyVM, PyHeader, PyFailure, PySuspended, PyVoid, PyStateKeys, PyExponentialRetryConfig, PyDoProgressAnyCompleted, PyDoProgressReadFromInput, PyDoProgressExecuteRun, PyDoWaitForPendingRun, PyDoProgressCancelSignalReceived, CANCEL_NOTIFICATION_HANDLE # pylint: disable=import-error,no-name-in-module,line-too-long

@dataclass
class Invocation:
Expand Down Expand Up @@ -90,14 +90,21 @@ class DoProgressCancelSignalReceived:
Represents a notification that a cancel signal has been received
"""

class DoWaitPendingRun:
"""
Represents a notification that a run is pending
"""

DO_PROGRESS_ANY_COMPLETED = DoProgressAnyCompleted()
DO_PROGRESS_READ_FROM_INPUT = DoProgressReadFromInput()
DO_PROGRESS_CANCEL_SIGNAL_RECEIVED = DoProgressCancelSignalReceived()
DO_WAIT_PENDING_RUN = DoWaitPendingRun()

DoProgressResult = typing.Union[DoProgressAnyCompleted,
DoProgressReadFromInput,
DoProgressExecuteRun,
DoProgressCancelSignalReceived]
DoProgressCancelSignalReceived,
DoWaitPendingRun]


# pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -157,6 +164,8 @@ def do_progress(self, handles: list[int]) -> DoProgressResult:
return DoProgressExecuteRun(result.handle)
if isinstance(result, PyDoProgressCancelSignalReceived):
return DO_PROGRESS_CANCEL_SIGNAL_RECEIVED
if isinstance(result, PyDoWaitForPendingRun):
return DO_WAIT_PENDING_RUN
raise ValueError(f"Unknown progress type: {result}")

def take_notification(self, handle: int) -> NotificationType:
Expand Down
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ struct PyDoProgressExecuteRun {
#[pyclass]
struct PyDoProgressCancelSignalReceived;

#[pyclass]
struct PyDoWaitForPendingRun;

#[pyclass]
pub struct PyCallHandle {
#[pyo3(get)]
Expand Down Expand Up @@ -349,7 +352,10 @@ impl PyVM {
.into_py(py)
.into_bound(py)
.into_any()),
Ok(DoProgressResponse::WaitingPendingRun) => panic!("Python SDK doesn't support concurrent pending runs, so this is not supposed to happen")
Ok(DoProgressResponse::WaitingPendingRun) => Ok(PyDoWaitForPendingRun
.into_py(py)
.into_bound(py)
.into_any()),
}
}

Expand Down Expand Up @@ -767,6 +773,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyDoProgressReadFromInput>()?;
m.add_class::<PyDoProgressExecuteRun>()?;
m.add_class::<PyDoProgressCancelSignalReceived>()?;
m.add_class::<PyDoWaitForPendingRun>()?;
m.add_class::<PyCallHandle>()?;

m.add("VMException", m.py().get_type_bound::<VMException>())?;
Expand Down