Skip to content

Support concurrent side effects more robustly #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,4 @@ jobs:
serviceContainerImage: "restatedev/test-services-python"
exclusionsFile: "test-services/exclusions.yaml"
testArtifactOutput: "sdk-python-integration-test-report"
serviceContainerEnvFile: "test-services/.env"
54 changes: 28 additions & 26 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ async def cancel_invocation(self) -> None:

await ctx.cancel_invocation(await f.invocation_id())
"""
inv = await self.invocation_id()
await self.context.cancel_invocation(inv)

class ServerSendHandle(SendHandle):
"""This class implements the send API"""
Expand Down Expand Up @@ -223,45 +225,47 @@ class SyncPoint:
This class implements a synchronization point.
"""

def __init__(self):
self._cond = asyncio.Condition()
def __init__(self) -> None:
self.cond: asyncio.Event | None = None

async def wait(self):
def awaiter(self):
"""Wait for the sync point."""
async with self._cond:
await self._cond.wait()
if self.cond is None:
self.cond = asyncio.Event()
return self.cond.wait()

async def arrive(self):
"""Arrive at the sync point."""
async with self._cond:
self._cond.notify_all()
"""arrive at the sync point."""
if self.cond is not None:
self.cond.set()

class Tasks:
"""
This class implements a list of tasks.
"""

def __init__(self) -> None:
self.tasks: List[asyncio.Future] = []
self.tasks: set[asyncio.Future] = set()

def add(self, task: asyncio.Future):
"""Add a task to the list."""
self.tasks.append(task)
self.tasks.add(task)

def safe_remove(_):
"""Remove the task from the list."""
try:
self.tasks.remove(task)
except ValueError:
except KeyError:
pass

task.add_done_callback(safe_remove)

def cancel(self):
"""Cancel all tasks in the list."""
for task in self.tasks:
task.cancel()
to_cancel = list(self.tasks)
self.tasks.clear()
for task in to_cancel:
task.cancel()

# pylint: disable=R0902
class ServerInvocationContext(ObjectContext):
Expand Down Expand Up @@ -358,13 +362,13 @@ def on_attempt_finished(self):
async def receive_and_notify_input(self):
"""Receive input from the state machine."""
chunk = await self.receive()
if chunk.get('type') == 'http.request':
if chunk.get('type') == 'http.disconnect':
raise DisconnectedException()
if chunk.get('body', None) is not None:
assert isinstance(chunk['body'], bytes)
self.vm.notify_input(chunk['body'])
if not chunk.get('more_body', False):
self.vm.notify_input_closed()
if chunk.get('type') == 'http.disconnect':
raise DisconnectedException()

async def take_and_send_output(self):
"""Take output from state machine and send it"""
Expand Down Expand Up @@ -398,9 +402,6 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
return
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
raise TerminalError("cancelled", 409)
if isinstance(do_progress_response, DoProgressReadFromInput):
await self.receive_and_notify_input()
continue
if isinstance(do_progress_response, DoProgressExecuteRun):
fn = self.run_coros_to_execute[do_progress_response.handle]
del self.run_coros_to_execute[do_progress_response.handle]
Expand All @@ -414,17 +415,16 @@ async def wrapper(f):
task = asyncio.create_task(wrapper(fn))
self.tasks.add(task)
continue
if isinstance(do_progress_response, DoWaitPendingRun):
sync_task = asyncio.create_task(self.sync_point.wait())
read_task = asyncio.create_task(self.receive_and_notify_input())
if isinstance(do_progress_response, (DoWaitPendingRun, DoProgressReadFromInput)):
sync_task = asyncio.create_task(self.sync_point.awaiter())
self.tasks.add(sync_task)

read_task = asyncio.create_task(self.receive_and_notify_input())
self.tasks.add(read_task)

done, _ = await asyncio.wait([sync_task, read_task], return_when=asyncio.FIRST_COMPLETED)
if read_task in done:
_ = read_task.result() # rethrow any exception
if sync_task in done:
continue

_ = read_task.result() # propagate exception

def _create_fetch_result_coroutine(self, handle: int, serde: Serde[T] | None = None):
"""Create a coroutine that fetches a result from a notification handle."""
Expand Down Expand Up @@ -520,6 +520,8 @@ async def create_run_coroutine(self,
except TerminalError as t:
failure = Failure(code=t.status_code, message=t.message)
self.vm.propose_run_completion_failure(handle, failure)
except asyncio.CancelledError as e:
raise e from None
# pylint: disable=W0718
except Exception as e:
if max_attempts is None and max_retry_duration is None:
Expand Down
1 change: 1 addition & 0 deletions test-services/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
RESTATE_CORE_LOG=trace