Skip to content

Add additional features #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/restate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
# types
from .context import Context, ObjectContext, ObjectSharedContext
from .context import WorkflowContext, WorkflowSharedContext
from .context import DurablePromise
from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, SendHandle
from .combinators import wait, gather, as_completed, ALL_COMPLETED, FIRST_COMPLETED
from .exceptions import TerminalError

from .endpoint import app

Expand All @@ -43,6 +45,15 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore
"WorkflowContext",
"WorkflowSharedContext",
"DurablePromise",
"RestateDurableFuture",
"RestateDurableCallFuture",
"SendHandle",
"TerminalError",
"app",
"test_harness",
"wait",
"gather",
"as_completed",
"ALL_COMPLETED",
"FIRST_COMPLETED",
]
119 changes: 119 additions & 0 deletions python/restate/combinators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
# pylint: disable=R0913,C0301,R0917
# pylint: disable=line-too-long
"""combines multiple futures into a single future"""

from typing import Any, List, Tuple
from restate.exceptions import TerminalError
from restate.context import RestateDurableFuture
from restate.server_context import ServerDurableFuture, ServerInvocationContext

FIRST_COMPLETED = 1
ALL_COMPLETED = 2

async def wait(*futures: RestateDurableFuture[Any], mode: int = FIRST_COMPLETED) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
"""
Blocks until at least one of the futures/all of the futures are completed.

Returns a tuple of two lists: the first list contains the futures that are completed,
the second list contains the futures that are not completed.

The mode parameter can be either FIRST_COMPLETED or ALL_COMPLETED.
Using FIRST_COMPLETED will return as soon as one of the futures is completed.
Using ALL_COMPLETED will return only when all futures are completed.

examples:

completed, waiting = await wait(f1, f2, f3, mode=FIRST_COMPLETED)
for completed_future in completed:
# do something with the completed future
print(await completed_future) # prints the result of the future

or

completed, waiting = await wait(f1, f2, f3, mode=ALL_COMPLETED)
assert waiting == []


"""
assert mode in (FIRST_COMPLETED, ALL_COMPLETED)

remaining = list(futures)
while remaining:
completed, waiting = await wait_completed(remaining)
if mode == FIRST_COMPLETED:
return completed, waiting
remaining = waiting

assert mode == ALL_COMPLETED
return list(futures), []

async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFuture[Any]]:
"""
Blocks until all futures are completed.

Returns a list of all futures.
"""
completed, _ = await wait(*futures, mode=ALL_COMPLETED)
return completed

async def as_completed(*futures: RestateDurableFuture[Any]):
"""
Returns an iterator that yields the futures as they are completed.

example:

async for future in as_completed(f1, f2, f3):
# do something with the completed future
print(await future) # prints the result of the future

"""
remaining = list(futures)
while remaining:
completed, waiting = await wait_completed(remaining)
for f in completed:
yield f
remaining = waiting

async def wait_completed(futures: List[RestateDurableFuture[Any]]) -> Tuple[List[RestateDurableFuture[Any]], List[RestateDurableFuture[Any]]]:
"""
Blocks until at least one of the futures is completed.

Returns a tuple of two lists: the first list contains the futures that are completed,
the second list contains the futures that are not completed.
"""
if not futures:
return [], []
handles: List[int] = []
context: ServerInvocationContext | None = None
for f in futures:
if not isinstance(f, ServerDurableFuture):
raise TerminalError("All futures must SDK created futures.")
if context is None:
context = f.context
elif context is not f.context:
raise TerminalError("All futures must be created by the same SDK context.")
if f.is_completed():
return [f], []
handles.append(f.source_notification_handle)

assert context is not None
await context.create_poll_or_cancel_coroutine(handles)
completed = []
uncompleted = []
for index, handle in enumerate(handles):
future = futures[index]
if context.vm.is_completed(handle):
completed.append(future)
else:
uncompleted.append(future)
return completed, uncompleted
6 changes: 6 additions & 0 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
Represents a durable future.
"""

@abc.abstractmethod
def is_completed(self) -> bool:
"""
Returns True if the future is completed, False otherwise.
"""

@abc.abstractmethod
def __await__(self):
pass
Expand Down
113 changes: 75 additions & 38 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,73 @@
O = TypeVar('O')





class ServerDurableFuture(RestateDurableFuture[T]):
"""This class implements a durable future API"""
value: T | None = None
error: TerminalError | None = None
state: typing.Literal["pending", "fulfilled", "rejected"] = "pending"

def __init__(self, handle: int, factory) -> None:
def __init__(self, context: "ServerInvocationContext", handle: int, awaitable_factory) -> None:
super().__init__()
self.factory = factory
self.handle = handle
self.context = context
self.source_notification_handle = handle
self.awaitable_factory = awaitable_factory
self.state = "pending"


def is_completed(self):
match self.state:
case "pending":
return self.context.vm.is_completed(self.source_notification_handle)
case "fulfilled":
return True
case "rejected":
return True


def __await__(self):
task = asyncio.create_task(self.factory())
return task.__await__()

async def await_point():
match self.state:
case "pending":
try:
self.value = await self.awaitable_factory()
self.state = "fulfilled"
return self.value
except TerminalError as t:
self.error = t
self.state = "rejected"
raise t
case "fulfilled":
return self.value
case "rejected":
assert self.error is not None
raise self.error


return await_point().__await__()

class ServerCallDurableFuture(RestateDurableCallFuture[T], ServerDurableFuture[T]):
"""This class implements a durable future but for calls"""
_invocation_id: typing.Optional[str] = None

def __init__(self, result_handle: int,
def __init__(self,
context: "ServerInvocationContext",
result_handle: int,
result_factory,
invocation_id_handle: int,
invocation_id_factory) -> None:
super().__init__(result_handle, result_factory)
super().__init__(context, result_handle, result_factory)
self.invocation_id_handle = invocation_id_handle
self.invocation_id_factory = invocation_id_factory


async def invocation_id(self) -> str:
"""Get the invocation id."""
if self._invocation_id is None:
self._invocation_id = await self.invocation_id_factory()
return self._invocation_id


class ServerSendHandle(SendHandle):
"""This class implements the send API"""
_invocation_id: typing.Optional[str]
Expand All @@ -90,29 +121,29 @@ async def invocation_id(self) -> str:
self._invocation_id = res
return res



async def async_value(n: Callable[[], T]) -> T:
"""convert a simple value to a coroutine."""
return n()


class ServerDurablePromise(DurablePromise):
"""This class implements a durable promise API"""

def __init__(self, server_context, name, serde) -> None:
def __init__(self, server_context: "ServerInvocationContext", name, serde) -> None:
super().__init__(name=name, serde=JsonSerde() if serde is None else serde)
self.server_context = server_context

def value(self) -> Awaitable[Any]:
vm: VMWrapper = self.server_context.vm
handle = vm.sys_get_promise(self.name)
coro = self.server_context.create_poll_or_cancel_coroutine(handle)
coro = self.server_context.create_poll_or_cancel_coroutine([handle])
serde = self.serde
assert serde is not None

async def await_point():
res = await coro
await coro
res = self.server_context.must_take_notification(handle)
if res is None:
return None
return serde.deserialize(res)

return await_point()
Expand All @@ -122,23 +153,33 @@ def resolve(self, value: Any) -> Awaitable[None]:
assert self.serde is not None
value_buffer = self.serde.serialize(value)
handle = vm.sys_complete_promise_success(self.name, value_buffer)
return self.server_context.create_poll_or_cancel_coroutine(handle)

async def await_point():
await self.server_context.create_poll_or_cancel_coroutine([handle])
self.server_context.must_take_notification(handle)

return await_point()

def reject(self, message: str, code: int = 500) -> Awaitable[None]:
vm: VMWrapper = self.server_context.vm
py_failure = Failure(code=code, message=message)
handle = vm.sys_complete_promise_failure(self.name, py_failure)
return self.server_context.create_poll_or_cancel_coroutine(handle)

async def await_point():
await self.server_context.create_poll_or_cancel_coroutine([handle])
self.server_context.must_take_notification(handle)

return await_point()

def peek(self) -> Awaitable[Any | None]:
vm: VMWrapper = self.server_context.vm
handle = vm.sys_peek_promise(self.name)
coro = self.server_context.create_poll_or_cancel_coroutine(handle)
serde = self.serde
assert serde is not None

async def await_point():
res = await coro
await self.server_context.create_poll_or_cancel_coroutine([handle])
res = self.server_context.must_take_notification(handle)
if res is None:
return None
return serde.deserialize(res)
Expand Down Expand Up @@ -246,24 +287,17 @@ def must_take_notification(self, handle):
return res


async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None:
async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> None:
"""Create a coroutine to poll the handle."""
await self.take_and_send_output()
while True:
if self.vm.is_completed(handle):
# Handle is completed
return self.must_take_notification(handle)

# Nothing ready yet, let's try to make some progress
do_progress_response = self.vm.do_progress([handle])
do_progress_response = self.vm.do_progress(handles)
if isinstance(do_progress_response, DoProgressAnyCompleted):
# One of the handles completed, we can continue
continue
# One of the handles completed
return
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
# Raise cancel signal
raise TerminalError("cancelled", 409)
if isinstance(do_progress_response, DoProgressReadFromInput):
# We need to read from input
chunk = await self.receive()
if chunk.get('body', None) is not None:
assert isinstance(chunk['body'], bytes)
Expand All @@ -280,28 +314,31 @@ def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurable
"""Create a durable future."""

async def transform():
res = await self.create_poll_or_cancel_coroutine(handle)
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
return res
return serde.deserialize(res)

return ServerDurableFuture(handle, transform)
return ServerDurableFuture(self, handle, transform)



def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T] | None = None) -> ServerCallDurableFuture[T]:
"""Create a durable future."""

async def transform():
res = await self.create_poll_or_cancel_coroutine(handle)
await self.create_poll_or_cancel_coroutine([handle])
res = self.must_take_notification(handle)
if res is None or serde is None:
return res
return serde.deserialize(res)

def inv_id_factory():
return self.create_poll_or_cancel_coroutine(invocation_id_handle)
async def inv_id_factory():
await self.create_poll_or_cancel_coroutine([invocation_id_handle])
return self.must_take_notification(invocation_id_handle)

return ServerCallDurableFuture(handle, transform, invocation_id_handle, inv_id_factory)
return ServerCallDurableFuture(self, handle, transform, invocation_id_handle, inv_id_factory)


def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
Expand Down