Skip to content

Commit 4fe56d1

Browse files
committed
Lift take async result outside of a poll coro
1 parent 215aaf2 commit 4fe56d1

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

python/restate/server_context.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -246,19 +246,19 @@ def must_take_notification(self, handle):
246246
return res
247247

248248

249-
async def create_poll_or_cancel_coroutine(self, handle) -> bytes | None:
249+
async def create_poll_or_cancel_coroutine(self, handles):
250250
"""Create a coroutine to poll the handle."""
251251
await self.take_and_send_output()
252252
while True:
253-
if self.vm.is_completed(handle):
253+
#if self.vm.is_completed(handle):
254254
# Handle is completed
255-
return self.must_take_notification(handle)
255+
# return self.must_take_notification(handle)
256256

257257
# Nothing ready yet, let's try to make some progress
258-
do_progress_response = self.vm.do_progress([handle])
258+
do_progress_response = self.vm.do_progress(handles)
259259
if isinstance(do_progress_response, DoProgressAnyCompleted):
260260
# One of the handles completed, we can continue
261-
continue
261+
return
262262
if isinstance(do_progress_response, DoProgressCancelSignalReceived):
263263
# Raise cancel signal
264264
raise TerminalError("cancelled", 409)
@@ -280,7 +280,8 @@ def create_df(self, handle: int, serde: Serde[T] | None = None) -> ServerDurable
280280
"""Create a durable future."""
281281

282282
async def transform():
283-
res = await self.create_poll_or_cancel_coroutine(handle)
283+
await self.create_poll_or_cancel_coroutine(handle)
284+
res = self.must_take_notification(handle)
284285
if res is None or serde is None:
285286
return res
286287
return serde.deserialize(res)
@@ -293,13 +294,15 @@ def create_call_df(self, handle: int, invocation_id_handle: int, serde: Serde[T]
293294
"""Create a durable future."""
294295

295296
async def transform():
296-
res = await self.create_poll_or_cancel_coroutine(handle)
297+
await self.create_poll_or_cancel_coroutine(handle)
298+
res = self.must_take_notification(handle)
297299
if res is None or serde is None:
298300
return res
299301
return serde.deserialize(res)
300302

301-
def inv_id_factory():
302-
return self.create_poll_or_cancel_coroutine(invocation_id_handle)
303+
async def inv_id_factory():
304+
await self.create_poll_or_cancel_coroutine(invocation_id_handle)
305+
return self.must_take_notification(invocation_id_handle)
303306

304307
return ServerCallDurableFuture(handle, transform, invocation_id_handle, inv_id_factory)
305308

0 commit comments

Comments
 (0)