diff --git a/examples/workflow.py b/examples/workflow.py index a21faa0..2cf7eaa 100644 --- a/examples/workflow.py +++ b/examples/workflow.py @@ -13,9 +13,13 @@ # pylint: disable=W0613 # pylint: disable=C0301 +from datetime import timedelta from restate import Workflow, WorkflowContext, WorkflowSharedContext -from restate.exceptions import TerminalError +from restate import select +from restate import TerminalError + +TIMEOUT = timedelta(seconds=10) payment = Workflow("payment") @@ -38,13 +42,17 @@ def payment_gateway(): ctx.set("status", "waiting for the payment provider to approve") # Wait for the payment to be verified - result = await ctx.promise("verify.payment").value() - if result == "approved": - ctx.set("status", "payment approved") - return { "success" : True } - ctx.set("status", "payment declined") - raise TerminalError(message="Payment declined", status_code=401) + match await select(result=ctx.promise("verify.payment").value(), timeout=ctx.sleep(TIMEOUT)): + case ['result', "approved"]: + ctx.set("status", "payment approved") + return { "success" : True } + case ['result', "declined"]: + ctx.set("status", "payment declined") + raise TerminalError(message="Payment declined", status_code=401) + case ['timeout', _]: + ctx.set("status", "payment verification timed out") + raise TerminalError(message="Payment verification timed out", status_code=410) @payment.handler() async def payment_verified(ctx: WorkflowSharedContext, result: str): diff --git a/python/restate/__init__.py b/python/restate/__init__.py index 790d88b..64648e3 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -22,7 +22,7 @@ # pylint: disable=line-too-long from .context import DurablePromise, RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, SendHandle from .exceptions import TerminalError -from .asyncio import as_completed, gather, wait_completed +from .asyncio import as_completed, gather, wait_completed, select from .endpoint import app @@ -56,4 +56,5 @@ def test_harness(app, follow_logs = False, restate_image = ""): # type: ignore "gather", "as_completed", "wait_completed", + "select" ] diff --git a/python/restate/asyncio.py b/python/restate/asyncio.py index 0fc516e..9417134 100644 --- a/python/restate/asyncio.py +++ b/python/restate/asyncio.py @@ -27,6 +27,38 @@ async def gather(*futures: RestateDurableFuture[Any]) -> List[RestateDurableFutu pass return list(futures) +async def select(**kws: RestateDurableFuture[Any]) -> List[Any]: + """ + Blocks until one of the futures is completed. + + Example: + + who, what = await select(car=f1, hotel=f2, flight=f3) + if who == "car": + print(what) + elif who == "hotel": + print(what) + elif who == "flight": + print(what) + + works the best with matching: + + match await select(result=ctx.promise("verify.payment"), timeout=ctx.sleep(timedelta(seconds=10))): + case ['result', "approved"]: + return { "success" : True } + case ['result', "declined"]: + raise TerminalError(message="Payment declined", status_code=401) + case ['timeout', _]: + raise TerminalError(message="Payment verification timed out", status_code=410) + + """ + if not kws: + raise ValueError("At least one future must be passed.") + reverse = { f: key for key, f in kws.items() } + async for f in as_completed(*kws.values()): + return [reverse[f], await f] + assert False, "unreachable" + async def as_completed(*futures: RestateDurableFuture[Any]): """ Returns an iterator that yields the futures as they are completed.