3535I = TypeVar ('I' )
3636O = TypeVar ('O' )
3737
38+ class LazyFuture :
39+ """
40+ Creates a task lazily, and allows multiple awaiters to the same coroutine.
41+ The async_def will be executed at most 1 times. (0 if __await__ or get() not called)
42+ """
43+ __slots__ = ['async_def' , 'task' ]
44+
45+ def __init__ (self , async_def : Callable [[], typing .Coroutine [Any , Any , T ]]) -> None :
46+ assert async_def is not None
47+ self .async_def = async_def
48+ self .task : asyncio .Task | None = None
49+
50+ def done (self ):
51+ """
52+ check if completed
53+ """
54+ return self .task is not None and self .task .done ()
55+
56+ async def get (self ) -> T :
57+ """Get the value of the future."""
58+ if self .task is None :
59+ self .task = asyncio .create_task (self .async_def ())
60+
61+ return await self .task
62+
63+ def __await__ (self ):
64+ return self .get ().__await__ ()
3865
3966class ServerDurableFuture (RestateDurableFuture [T ]):
4067 """This class implements a durable future API"""
41- value : T | None = None
42- error : TerminalError | None = None
43- state : typing .Literal ["pending" , "fulfilled" , "rejected" ] = "pending"
4468
45- def __init__ (self , context : "ServerInvocationContext" , handle : int , awaitable_factory ) -> None :
69+ def __init__ (self , context : "ServerInvocationContext" , handle : int , async_def ) -> None :
4670 super ().__init__ ()
4771 self .context = context
48- self .source_notification_handle = handle
49- self .awaitable_factory = awaitable_factory
50- self .state = "pending"
51-
72+ self .handle = handle
73+ self .future = LazyFuture (async_def )
5274
5375 def is_completed (self ):
54- match self .state :
55- case "pending" :
56- return self .context .vm .is_completed (self .source_notification_handle )
57- case "fulfilled" :
58- return True
59- case "rejected" :
60- return True
61-
62- def map_value (self , mapper : Callable [[T ], O ]) -> RestateDurableFuture [O ]:
63- """Map the value of the future."""
64- async def mapper_coro ():
65- return mapper (await self )
66-
67- return ServerDurableFuture (self .context , self .source_notification_handle , mapper_coro )
68-
76+ """
77+ A future is completed, either it was physically completed and its value has been collected.
78+ OR it might not yet physically completed (i.e. the async_def didn't finish yet) BUT our VM
79+ already has a completion value for it.
80+ """
81+ return self .future .done () or self .context .vm .is_completed (self .handle )
6982
7083 def __await__ (self ):
71-
72- async def await_point ():
73- match self .state :
74- case "pending" :
75- try :
76- self .value = await self .awaitable_factory ()
77- self .state = "fulfilled"
78- return self .value
79- except TerminalError as t :
80- self .error = t
81- self .state = "rejected"
82- raise t
83- case "fulfilled" :
84- return self .value
85- case "rejected" :
86- assert self .error is not None
87- raise self .error
88-
89-
90- return await_point ().__await__ ()
84+ return self .future .__await__ ()
9185
9286class ServerCallDurableFuture (RestateDurableCallFuture [T ], ServerDurableFuture [T ]):
9387 """This class implements a durable future but for calls"""
94- _invocation_id : typing .Optional [str ] = None
9588
9689 def __init__ (self ,
9790 context : "ServerInvocationContext" ,
9891 result_handle : int ,
99- result_factory ,
100- invocation_id_handle : int ,
101- invocation_id_factory ) -> None :
102- super ().__init__ (context , result_handle , result_factory )
103- self .invocation_id_handle = invocation_id_handle
104- self .invocation_id_factory = invocation_id_factory
105-
92+ result_async_def ,
93+ invocation_id_async_def ) -> None :
94+ super ().__init__ (context , result_handle , result_async_def )
95+ self .invocation_id_future = LazyFuture (invocation_id_async_def )
10696
10797 async def invocation_id (self ) -> str :
10898 """Get the invocation id."""
109- if self ._invocation_id is None :
110- self ._invocation_id = await self .invocation_id_factory ()
111- return self ._invocation_id
99+ return await self .invocation_id_future .get ()
112100
113101class ServerSendHandle (SendHandle ):
114102 """This class implements the send API"""
115- _invocation_id : typing .Optional [str ]
116103
117- def __init__ (self , context , handle : int ) -> None :
104+ def __init__ (self , context : "ServerInvocationContext" , handle : int ) -> None :
118105 super ().__init__ ()
119- self .handle = handle
120- self .context = context
121- self ._invocation_id = None
106+
107+ async def coro ():
108+ if not context .vm .is_completed (handle ):
109+ await context .create_poll_or_cancel_coroutine ([handle ])
110+ return context .must_take_notification (handle )
111+
112+ self .future = LazyFuture (coro )
122113
123114 async def invocation_id (self ) -> str :
124115 """Get the invocation id."""
125- if self ._invocation_id is not None :
126- return self ._invocation_id
127- res = await self .context .create_poll_or_cancel_coroutine (self .handle )
128- self ._invocation_id = res
129- return res
116+ return await self .future
130117
131118async def async_value (n : Callable [[], T ]) -> T :
132119 """convert a simple value to a coroutine."""
@@ -334,6 +321,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
334321 continue
335322 if isinstance (do_progress_response , DoProgressExecuteRun ):
336323 fn = self .run_coros_to_execute [do_progress_response .handle ]
324+ del self .run_coros_to_execute [do_progress_response .handle ]
337325 assert fn is not None
338326
339327 async def wrapper (f ):
@@ -346,11 +334,12 @@ async def wrapper(f):
346334 if isinstance (do_progress_response , DoWaitPendingRun ):
347335 await self .sync_point .wait ()
348336
349- def create_df (self , handle : int , serde : Serde [T ] | None = None ) -> ServerDurableFuture [T ]:
337+ def create_future (self , handle : int , serde : Serde [T ] | None = None ) -> ServerDurableFuture [T ]:
350338 """Create a durable future."""
351339
352340 async def transform ():
353- await self .create_poll_or_cancel_coroutine ([handle ])
341+ if not self .vm .is_completed (handle ):
342+ await self .create_poll_or_cancel_coroutine ([handle ])
354343 res = self .must_take_notification (handle )
355344 if res is None or serde is None :
356345 return res
@@ -359,30 +348,31 @@ async def transform():
359348 return ServerDurableFuture (self , handle , transform )
360349
361350
362-
363- def create_call_df (self , handle : int , invocation_id_handle : int , serde : Serde [T ] | None = None ) -> ServerCallDurableFuture [T ]:
351+ def create_call_future (self , handle : int , invocation_id_handle : int , serde : Serde [T ] | None = None ) -> ServerCallDurableFuture [T ]:
364352 """Create a durable future."""
365353
366354 async def transform ():
367- await self .create_poll_or_cancel_coroutine ([handle ])
355+ if not self .vm .is_completed (handle ):
356+ await self .create_poll_or_cancel_coroutine ([handle ])
368357 res = self .must_take_notification (handle )
369358 if res is None or serde is None :
370359 return res
371360 return serde .deserialize (res )
372361
373362 async def inv_id_factory ():
374- await self .create_poll_or_cancel_coroutine ([invocation_id_handle ])
363+ if not self .vm .is_completed (invocation_id_handle ):
364+ await self .create_poll_or_cancel_coroutine ([invocation_id_handle ])
375365 return self .must_take_notification (invocation_id_handle )
376366
377- return ServerCallDurableFuture (self , handle , transform , invocation_id_handle , inv_id_factory )
367+ return ServerCallDurableFuture (self , handle , transform , inv_id_factory )
378368
379369
380370 def get (self , name : str , serde : Serde [T ] = JsonSerde ()) -> Awaitable [Optional [T ]]:
381371 handle = self .vm .sys_get_state (name )
382- return self .create_df (handle , serde ) # type: ignore
372+ return self .create_future (handle , serde ) # type: ignore
383373
384374 def state_keys (self ) -> Awaitable [List [str ]]:
385- return self .create_df (self .vm .sys_get_state_keys ()) # type: ignore
375+ return self .create_future (self .vm .sys_get_state_keys ()) # type: ignore
386376
387377 def set (self , name : str , value : T , serde : Serde [T ] = JsonSerde ()) -> None :
388378 """Set the value associated with the given name."""
@@ -446,13 +436,13 @@ def run(self,
446436 self .run_coros_to_execute [handle ] = lambda : self .create_run_coroutine (handle , action , serde , max_attempts , max_retry_duration )
447437
448438 # Prepare response coroutine
449- return self .create_df (handle , serde ) # type: ignore
439+ return self .create_future (handle , serde ) # type: ignore
450440
451441
452442 def sleep (self , delta : timedelta ) -> RestateDurableFuture [None ]:
453443 # convert timedelta to milliseconds
454444 millis = int (delta .total_seconds () * 1000 )
455- return self .create_df (self .vm .sys_sleep (millis )) # type: ignore
445+ return self .create_future (self .vm .sys_sleep (millis )) # type: ignore
456446
457447 def do_call (self ,
458448 tpe : Callable [[Any , I ], Awaitable [O ]],
@@ -501,7 +491,7 @@ def do_raw_call(self,
501491 idempotency_key = idempotency_key ,
502492 headers = headers )
503493
504- return self .create_call_df (handle = handle .result_handle ,
494+ return self .create_call_future (handle = handle .result_handle ,
505495 invocation_id_handle = handle .invocation_id_handle ,
506496 serde = output_serde )
507497
@@ -582,7 +572,7 @@ def awakeable(self,
582572 serde : typing .Optional [Serde [I ]] = JsonSerde ()) -> typing .Tuple [str , RestateDurableFuture [Any ]]:
583573 assert serde is not None
584574 name , handle = self .vm .sys_awakeable ()
585- return name , self .create_df (handle , serde )
575+ return name , self .create_future (handle , serde )
586576
587577 def resolve_awakeable (self ,
588578 name : str ,
@@ -613,4 +603,4 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -
613603 raise ValueError ("invocation_id cannot be None" )
614604 assert serde is not None
615605 handle = self .vm .attach_invocation (invocation_id )
616- return self .create_df (handle , serde )
606+ return self .create_future (handle , serde )
0 commit comments