11# SPDX-License-Identifier: Apache-2.0 
22import  asyncio 
3+ import  contextlib 
34import  queue 
45import  uuid 
56import  weakref 
67from  abc  import  ABC , abstractmethod 
8+ from  collections  import  deque 
79from  collections .abc  import  Awaitable , Sequence 
810from  concurrent .futures  import  Future 
911from  dataclasses  import  dataclass , field 
@@ -396,6 +398,12 @@ def __init__(
396398            self ._wait_for_engine_startup ()
397399
398400            self .utility_results : dict [int , AnyFuture ] =  {}
401+ 
402+             # Request objects which may contain pytorch-allocated tensors 
403+             # that we need to keep references to until zmq is done with the 
404+             # underlying data. 
405+             self .pending_messages  =  deque [tuple [zmq .MessageTracker , Any ]]()
406+ 
399407            success  =  True 
400408        finally :
401409            if  not  success :
@@ -459,6 +467,14 @@ def ensure_alive(self):
459467        if  self .resources .engine_dead :
460468            raise  EngineDeadError ()
461469
470+     def  add_pending_message (self , tracker : zmq .MessageTracker , msg : Any ):
471+         if  not  tracker .done :
472+             self .pending_messages .appendleft ((tracker , msg ))
473+ 
474+     def  free_pending_messages (self ):
475+         while  self .pending_messages  and  self .pending_messages [- 1 ][0 ].done :
476+             self .pending_messages .pop ()
477+ 
462478
463479def  _process_utility_output (output : UtilityOutput ,
464480                            utility_results : dict [int , AnyFuture ]):
@@ -544,10 +560,18 @@ def get_output(self) -> EngineCoreOutputs:
544560
545561    def  _send_input (self , request_type : EngineCoreRequestType , request : Any ):
546562        self .ensure_alive ()
563+         self .free_pending_messages ()
547564        # (Identity, RequestType, SerializedRequest) 
548565        msg  =  (self .core_engine .identity , request_type .value ,
549566               * self .encoder .encode (request ))
550-         self .input_socket .send_multipart (msg , copy = False )
567+ 
568+         if  len (msg ) <=  3 :
569+             # No auxiliary buffers => no tensor backing buffers in request. 
570+             self .input_socket .send_multipart (msg , copy = False )
571+             return 
572+ 
573+         tracker  =  self .input_socket .send_multipart (msg , copy = False , track = True )
574+         self .add_pending_message (tracker , request )
551575
552576    def  call_utility (self , method : str , * args ) ->  Any :
553577        call_id  =  uuid .uuid1 ().int  >>  64 
@@ -698,19 +722,38 @@ async def get_output_async(self) -> EngineCoreOutputs:
698722    def  _send_input (self ,
699723                    request_type : EngineCoreRequestType ,
700724                    request : Any ,
701-                     engine : Optional [CoreEngine ] =  None ) ->  Awaitable [None ]:
725+                     engine : Optional [CoreEngine ] =  None ) ->  Awaitable [Any ]:
702726        self .ensure_alive ()
703727        if  engine  is  None :
704728            engine  =  self .core_engine 
705729
706730        message  =  (request_type .value , * self .encoder .encode (request ))
707-         return  self ._send_input_message (message , engine )
708- 
709-     def  _send_input_message (self , message : tuple [bytestr , ...],
710-                             engine : CoreEngine ) ->  Awaitable [None ]:
731+         return  self ._send_input_message (message , engine , request )
732+ 
733+     def  _send_input_message (self , message : tuple [bytestr ,
734+                                                  ...], engine : CoreEngine ,
735+                             objects : Any ) ->  Awaitable [Any ]:
736+         """ 
737+         objects is a reference to retain until zmq is finished with the 
738+         buffers, in case they were extracted from tensors in the request. 
739+         """ 
711740        self .ensure_alive ()
712-         message  =  (engine .identity , ) +  message 
713-         return  self .input_socket .send_multipart (message , copy = False )
741+         self .free_pending_messages ()
742+ 
743+         msg  =  (engine .identity , ) +  message 
744+         if  not  objects  or  len (msg ) <=  3 :
745+             # No auxiliary buffers => no tensor backing buffers in request. 
746+             return  self .input_socket .send_multipart (msg , copy = False )
747+ 
748+         future : asyncio .Future [zmq .MessageTracker ]
749+         future  =  self .input_socket .send_multipart (msg , copy = False , track = True )
750+ 
751+         def  add_pending (f : asyncio .Future [zmq .MessageTracker ]):
752+             with  contextlib .suppress (BaseException ):
753+                 self .add_pending_message (f .result (), objects )
754+ 
755+         future .add_done_callback (add_pending )
756+         return  future 
714757
715758    async  def  call_utility_async (self , method : str , * args ) ->  Any :
716759        return  await  self ._call_utility_async (method ,
@@ -724,7 +767,7 @@ async def _call_utility_async(self, method: str, *args,
724767        self .utility_results [call_id ] =  future 
725768        message  =  (EngineCoreRequestType .UTILITY .value , * self .encoder .encode (
726769            (call_id , method , args )))
727-         await  self ._send_input_message (message , engine )
770+         await  self ._send_input_message (message , engine ,  args )
728771        self ._ensure_output_queue_task ()
729772        return  await  future 
730773
0 commit comments