Skip to content

Use a custom future #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@
RunAction = Union[Callable[[], T], Callable[[], Awaitable[T]]]


# pylint: disable=R0903
class RestateDurableFuture(typing.Generic[T], Awaitable[T]):
"""
Represents a durable future.
"""

@abc.abstractmethod
def __await__(self):
pass



@dataclass
class Request:
"""
Expand Down Expand Up @@ -110,7 +122,7 @@ def run(self,
action: RunAction[T],
serde: Serde[T] = JsonSerde(),
max_attempts: typing.Optional[int] = None,
max_retry_duration: typing.Optional[timedelta] = None) -> Awaitable[T]:
max_retry_duration: typing.Optional[timedelta] = None) -> RestateDurableFuture[T]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just be aware that this needs additional wiring, so step 4 :)

"""
Runs the given action with the given name.

Expand All @@ -126,7 +138,7 @@ def run(self,
"""

@abc.abstractmethod
def sleep(self, delta: timedelta) -> Awaitable[None]:
def sleep(self, delta: timedelta) -> RestateDurableFuture[None]:
"""
Suspends the current invocation for the given duration
"""
Expand All @@ -135,7 +147,7 @@ def sleep(self, delta: timedelta) -> Awaitable[None]:
def service_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
arg: I,
idempotency_key: str | None = None) -> Awaitable[O]:
idempotency_key: str | None = None) -> RestateDurableFuture[O]:
"""
Invokes the given service with the given argument.
"""
Expand All @@ -158,7 +170,7 @@ def object_call(self,
key: str,
arg: I,
idempotency_key: str | None = None,
) -> Awaitable[O]:
) -> RestateDurableFuture[O]:
"""
Invokes the given object with the given argument.
"""
Expand All @@ -181,7 +193,7 @@ def workflow_call(self,
key: str,
arg: I,
idempotency_key: str | None = None,
) -> Awaitable[O]:
) -> RestateDurableFuture[O]:
"""
Invokes the given workflow with the given argument.
"""
Expand All @@ -205,7 +217,7 @@ def generic_call(self,
handler: str,
arg: bytes,
key: Optional[str] = None,
idempotency_key: str | None = None) -> Awaitable[bytes]:
idempotency_key: str | None = None) -> RestateDurableFuture[bytes]:
"""
Invokes the given generic service/handler with the given argument.
"""
Expand All @@ -225,7 +237,7 @@ def generic_send(self,

@abc.abstractmethod
def awakeable(self,
serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]:
serde: Serde[T] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
"""
Returns the name of the awakeable and the future to be awaited.
"""
Expand All @@ -252,7 +264,7 @@ def cancel(self, invocation_id: str):
"""

@abc.abstractmethod
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> T:
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
"""
Attaches the invocation with the given id.
"""
Expand Down Expand Up @@ -282,13 +294,13 @@ def key(self) -> str:
@abc.abstractmethod
def get(self,
name: str,
serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[Any]]:
serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[Optional[Any]]:
"""
Retrieves the value associated with the given name.
"""

@abc.abstractmethod
def state_keys(self) -> Awaitable[List[str]]:
def state_keys(self) -> RestateDurableFuture[List[str]]:
"""
Returns the list of keys in the store.
"""
Expand Down
114 changes: 59 additions & 55 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing
import traceback

from restate.context import DurablePromise, ObjectContext, Request, SendHandle
from restate.context import DurablePromise, ObjectContext, Request, RestateDurableFuture, SendHandle
from restate.exceptions import TerminalError
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, JsonSerde, Serde
Expand All @@ -39,6 +39,30 @@
# disable too few public methods
# pylint: disable=R0903

# pylint: disable=W0511


class ServerDurableFuture(RestateDurableFuture[T]):
"""This class implements a durable future API"""
value: T | None = None
metadata: Dict[str, Any] | None = None

def __init__(self, handle: int, factory) -> None:
super().__init__()
self.factory = factory
self.handle = handle

def with_metadata(self, **metadata) -> 'ServerDurableFuture':
"""Add metadata to the future."""
self.metadata = metadata
return self

def __await__(self):
print("..........Awaiting............", flush=True)
task = asyncio.create_task(self.factory())
return task.__await__()


class ServerSendHandle(SendHandle):
"""This class implements the send API"""
_invocation_id: typing.Optional[str]
Expand All @@ -57,6 +81,8 @@ async def invocation_id(self) -> str:
self._invocation_id = res
return res



async def async_value(n: Callable[[], T]) -> T:
"""convert a simple value to a coroutine."""
return n()
Expand Down Expand Up @@ -210,6 +236,7 @@ def must_take_notification(self, handle):
raise TerminalError(res.message, res.code)
return res


async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None:
"""Create a coroutine to poll the handle."""
await self.take_and_send_output()
Expand Down Expand Up @@ -240,20 +267,25 @@ async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None:
await self.take_and_send_output()


def get(self, name: str, serde: Serde[T] = JsonSerde()) -> typing.Awaitable[Optional[Any]]:
coro = self.create_poll_or_cancel_coroutine(self.vm.sys_get_state(name))
def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurableFuture[T]:
"""Create a durable future."""

async def await_point():
"""Wait for this handle to be resolved."""
res = await coro
if res is None:
return None
async def transform():
res = await self.create_poll_or_cancel_coroutine(handle)
if res is None or serde is None:
return res
return serde.deserialize(res)

return await_point() # do not await here, the caller will do it.
return ServerDurableFuture(handle, transform)



def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]]:
handle = self.vm.sys_get_state(name)
return self.create_df(handle, serde) # type: ignore

def state_keys(self) -> Awaitable[List[str]]:
return self.create_poll_or_cancel_coroutine(self.vm.sys_get_state_keys()) # type: ignore
return self.create_df(self.vm.sys_get_state_keys()) # type: ignore

def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
"""Set the value associated with the given name."""
Expand Down Expand Up @@ -309,29 +341,22 @@ def run(self,
action: Callable[[], T] | Callable[[], Awaitable[T]],
serde: Optional[Serde[T]] = JsonSerde(),
max_attempts: Optional[int] = None,
max_retry_duration: Optional[timedelta] = None) -> Awaitable[T]:
max_retry_duration: Optional[timedelta] = None) -> RestateDurableFuture[T]:
assert serde is not None
handle = self.vm.sys_run(name)

# Register closure to run
# TODO: use thunk to avoid coro leak warning.
self.run_coros_to_execute[handle] = self.create_run_coroutine(handle, action, serde, max_attempts, max_retry_duration)

# Prepare response coroutine
coro = self.create_poll_or_cancel_coroutine(handle)
async def await_point():
"""Wait for this handle to be resolved."""
res = await coro
if res is None:
return None
return serde.deserialize(res)
return self.create_df(handle, serde) # type: ignore

return await_point() # do not await here, the caller will do it.


def sleep(self, delta: timedelta) -> Awaitable[None]:
def sleep(self, delta: timedelta) -> RestateDurableFuture[None]:
# convert timedelta to milliseconds
millis = int(delta.total_seconds() * 1000)
return self.create_poll_or_cancel_coroutine(self.vm.sys_sleep(millis)) # type: ignore
return self.create_df(self.vm.sys_sleep(millis)) # type: ignore

def do_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
Expand All @@ -341,7 +366,7 @@ def do_call(self,
send: bool = False,
idempotency_key: str | None = None,
headers: typing.List[typing.Tuple[str, str]] | None = None
) -> Awaitable[O] | SendHandle:
) -> RestateDurableFuture[O] | SendHandle:
"""Make an RPC call to the given handler"""
target_handler = handler_from_callable(tpe)
service=target_handler.service_tag.name
Expand All @@ -362,7 +387,7 @@ def do_raw_call(self,
send: bool = False,
idempotency_key: str | None = None,
headers: typing.List[typing.Tuple[str, str]] | None = None
) -> Awaitable[O] | SendHandle:
) -> RestateDurableFuture[O] | SendHandle:
"""Make an RPC call to the given handler"""
parameter = input_serde.serialize(input_param)
if send_delay:
Expand All @@ -380,19 +405,15 @@ def do_raw_call(self,
idempotency_key=idempotency_key,
headers=headers)

async def await_point(s: ServerInvocationContext, h, o: Serde[O]):
"""Wait for this handle to be resolved, and deserialize the response."""
res = await s.create_poll_or_cancel_coroutine(h)
return o.deserialize(res) # type: ignore

return await_point(self, handle.result_handle, output_serde)
# TODO: specialize this future for calls!
return self.create_df(handle=handle.result_handle, serde=output_serde).with_metadata(invocation_id=handle.invocation_id_handle)

def service_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
arg: I,
idempotency_key: str | None = None,
headers: typing.List[typing.Tuple[str, str]] | None = None
) -> Awaitable[O]:
) -> RestateDurableFuture[O]:
coro = self.do_call(tpe, arg, idempotency_key=idempotency_key, headers=headers)
assert not isinstance(coro, SendHandle)
return coro
Expand All @@ -408,7 +429,7 @@ def object_call(self,
arg: I,
idempotency_key: str | None = None,
headers: typing.List[typing.Tuple[str, str]] | None = None
) -> Awaitable[O]:
) -> RestateDurableFuture[O]:
coro = self.do_call(tpe, arg, key, idempotency_key=idempotency_key, headers=headers)
assert not isinstance(coro, SendHandle)
return coro
Expand All @@ -424,15 +445,15 @@ def workflow_call(self,
arg: I,
idempotency_key: str | None = None,
headers: typing.List[typing.Tuple[str, str]] | None = None
) -> Awaitable[O]:
) -> RestateDurableFuture[O]:
return self.object_call(tpe, key, arg, idempotency_key=idempotency_key, headers=headers)

def workflow_send(self, tpe: Callable[[Any, I], Awaitable[O]], key: str, arg: I, send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> SendHandle:
send = self.object_send(tpe, key, arg, send_delay, idempotency_key=idempotency_key, headers=headers)
assert isinstance(send, SendHandle)
return send

def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> Awaitable[bytes]:
def generic_call(self, service: str, handler: str, arg: bytes, key: str | None = None, idempotency_key: str | None = None, headers: typing.List[typing.Tuple[str, str]] | None = None) -> RestateDurableFuture[bytes]:
serde = BytesSerde()
call_handle = self.do_raw_call(service=service,
handler=handler,
Expand Down Expand Up @@ -461,19 +482,10 @@ def generic_send(self, service: str, handler: str, arg: bytes, key: str | None =
return send_handle

def awakeable(self,
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, Awaitable[Any]]:
serde: typing.Optional[Serde[I]] = JsonSerde()) -> typing.Tuple[str, RestateDurableFuture[Any]]:
assert serde is not None
name, handle = self.vm.sys_awakeable()
coro = self.create_poll_or_cancel_coroutine(handle)

async def await_point():
"""Wait for this handle to be resolved."""
res = await coro
assert res is not None
return serde.deserialize(res)


return name, await_point()
return name, self.create_df(handle, serde)

def resolve_awakeable(self,
name: str,
Expand All @@ -499,17 +511,9 @@ def cancel(self, invocation_id: str):
raise ValueError("invocation_id cannot be None")
self.vm.sys_cancel(invocation_id)

def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> T:
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> RestateDurableFuture[T]:
if invocation_id is None:
raise ValueError("invocation_id cannot be None")
assert serde is not None
handle = self.vm.attach_invocation(invocation_id)
coro = self.create_poll_or_cancel_coroutine(handle)

async def await_point():
"""Wait for this handle to be resolved."""
res = await coro
assert res is not None
return serde.deserialize(res)

return await_point()
return self.create_df(handle, serde)