Skip to content

Wait for the side effect to complete #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions python/restate/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from restate.discovery import compute_discovery_json
from restate.endpoint import Endpoint
from restate.server_context import ServerInvocationContext, DisconnectedException
from restate.server_types import Receive, Scope, Send, binary_to_header, header_to_binary
from restate.server_types import Receive, ReceiveChannel, Scope, Send, binary_to_header, header_to_binary # pylint: disable=line-too-long
from restate.vm import VMWrapper
from restate._internal import PyIdentityVerifier, IdentityVerificationException # pylint: disable=import-error,no-name-in-module
from restate._internal import SDK_VERSION # pylint: disable=import-error,no-name-in-module
Expand Down Expand Up @@ -85,7 +85,7 @@ async def send_health_check(send: Send):
async def process_invocation_to_completion(vm: VMWrapper,
handler,
attempt_headers: Dict[str, str],
receive: Receive,
receive: ReceiveChannel,
send: Send):
"""Invoke the user code."""
status, res_headers = vm.get_response_head()
Expand Down Expand Up @@ -171,6 +171,7 @@ def parse_path(request: str) -> ParsedPath:
# anything other than invoke is 404
return { "type": "unknown" , "service": None, "handler": None }


def asgi_app(endpoint: Endpoint):
"""Create an ASGI-3 app for the given endpoint."""

Expand Down Expand Up @@ -201,7 +202,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
identity_verifier.verify(request_headers, request_path)
except IdentityVerificationException:
# Identify verification failed, send back unauthorized and close
await send_status(send, receive,401)
await send_status(send, receive, 401)
return

# might be a discovery request
Expand All @@ -228,11 +229,15 @@ async def app(scope: Scope, receive: Receive, send: Send):
# At this point we have a valid handler.
# Let us setup restate's execution context for this invocation and handler.
#
await process_invocation_to_completion(VMWrapper(request_headers),
handler,
dict(request_headers),
receive,
send)
receive_channel = ReceiveChannel(receive)
try:
await process_invocation_to_completion(VMWrapper(request_headers),
handler,
dict(request_headers),
receive_channel,
send)
finally:
await receive_channel.close()
except LifeSpanNotImplemented as e:
raise e
except Exception as e:
Expand Down
58 changes: 14 additions & 44 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from restate.exceptions import TerminalError
from restate.handler import Handler, handler_from_callable, invoke_handler
from restate.serde import BytesSerde, DefaultSerde, JsonSerde, Serde
from restate.server_types import Receive, Send
from restate.server_types import ReceiveChannel, Send
from restate.vm import Failure, Invocation, NotReady, SuspendedException, VMWrapper, RunRetryConfig # pylint: disable=line-too-long
from restate.vm import DoProgressAnyCompleted, DoProgressCancelSignalReceived, DoProgressReadFromInput, DoProgressExecuteRun, DoWaitPendingRun

Expand Down Expand Up @@ -220,25 +220,6 @@ def peek(self) -> Awaitable[Any | None]:
# disable too many public method
# pylint: disable=R0904

class SyncPoint:
"""
This class implements a synchronization point.
"""

def __init__(self) -> None:
self.cond: asyncio.Event | None = None

def awaiter(self):
"""Wait for the sync point."""
if self.cond is None:
self.cond = asyncio.Event()
return self.cond.wait()

async def arrive(self):
"""arrive at the sync point."""
if self.cond is not None:
self.cond.set()

class Tasks:
"""
This class implements a list of tasks.
Expand Down Expand Up @@ -284,7 +265,8 @@ def __init__(self,
invocation: Invocation,
attempt_headers: Dict[str, str],
send: Send,
receive: Receive) -> None:
receive: ReceiveChannel
) -> None:
super().__init__()
self.vm = vm
self.handler = handler
Expand All @@ -293,7 +275,6 @@ def __init__(self,
self.send = send
self.receive = receive
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
self.sync_point = SyncPoint()
self.request_finished_event = asyncio.Event()
self.tasks = Tasks()

Expand Down Expand Up @@ -365,18 +346,6 @@ def on_attempt_finished(self):
# ignore the cancelled error
pass


async def receive_and_notify_input(self):
"""Receive input from the state machine."""
chunk = await self.receive()
if chunk.get('type') == 'http.disconnect':
raise DisconnectedException()
if chunk.get('body', None) is not None:
assert isinstance(chunk['body'], bytes)
self.vm.notify_input(chunk['body'])
if not chunk.get('more_body', False):
self.vm.notify_input_closed()

async def take_and_send_output(self):
"""Take output from state machine and send it"""
output = self.vm.take_output()
Expand Down Expand Up @@ -417,21 +386,22 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
async def wrapper(f):
await f()
await self.take_and_send_output()
await self.sync_point.arrive()
await self.receive.tx({ 'type' : 'restate.run_completed', 'body' : bytes(), 'more_body' : True})

task = asyncio.create_task(wrapper(fn))
self.tasks.add(task)
continue
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
sync_task = asyncio.create_task(self.sync_point.awaiter())
self.tasks.add(sync_task)

read_task = asyncio.create_task(self.receive_and_notify_input())
self.tasks.add(read_task)

done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
if read_task in done:
_ = read_task.result() # propagate exception
chunk = await self.receive()
if chunk.get('type') == 'restate.run_completed':
continue
if chunk.get('type') == 'http.disconnect':
raise DisconnectedException()
if chunk.get('body', None) is not None:
assert isinstance(chunk['body'], bytes)
self.vm.notify_input(chunk['body'])
if not chunk.get('more_body', False):
self.vm.notify_input_closed()

def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
"""Create a coroutine that fetches a result from a notification handle."""
Expand Down
42 changes: 41 additions & 1 deletion python/restate/server_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
:see: https://github.com/django/asgiref/blob/main/asgiref/typing.py
"""

import asyncio
from typing import (Awaitable, Callable, Dict, Iterable, List,
Tuple, Union, TypedDict, Literal, Optional, NotRequired, Any)

Expand Down Expand Up @@ -41,7 +42,7 @@ class Scope(TypedDict):

class HTTPRequestEvent(TypedDict):
"""ASGI Request event"""
type: Literal["http.request"]
type: Literal["http.request", "restate.run_completed"]
body: bytes
more_body: bool

Expand Down Expand Up @@ -86,3 +87,42 @@ def header_to_binary(headers: Iterable[Tuple[str, str]]) -> List[Tuple[bytes, by
def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str, str]]:
"""Convert a list of binary headers to a list of headers."""
return [ (k.decode('utf-8'), v.decode('utf-8')) for k,v in headers ]

class ReceiveChannel:
"""ASGI receive channel."""

def __init__(self, receive: Receive):
self.queue = asyncio.Queue[ASGIReceiveEvent]()

async def loop():
"""Receive loop."""
while True:
event = await receive()
await self.queue.put(event)
if event.get('type') == 'http.disconnect':
break

self.task = asyncio.create_task(loop())

async def rx(self) -> ASGIReceiveEvent:
"""Get the next message."""
what = await self.queue.get()
self.queue.task_done()
return what

async def __call__(self):
"""Get the next message."""
return await self.rx()

async def tx(self, what: ASGIReceiveEvent):
"""Add a message."""
await self.queue.put(what)

async def close(self):
"""Close the channel."""
if self.task and not self.task.done():
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass