88import uuid
99import weakref
1010from abc import ABC , abstractmethod
11- from collections .abc import Awaitable , Sequence
11+ from collections .abc import Awaitable
1212from concurrent .futures import Future
1313from dataclasses import dataclass , field
1414from threading import Thread
3535
3636_R = TypeVar ('_R' ) # Return type for collective_rpc
3737
38+ STARTUP_POLL_PERIOD_MS = 10000
39+
3840
3941class EngineCoreClient (ABC ):
4042 """
@@ -261,15 +263,13 @@ def __init__(
261263 vllm_config : VllmConfig ,
262264 executor_class : type [Executor ],
263265 log_stats : bool ,
264- ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
266+ input_path : str ,
265267 output_path : str ,
266268 index : int = 0 ,
267269 local_dp_rank : int = 0 ,
268270 ):
269- # Paths and sockets for IPC.
270- input_path = get_open_zmq_ipc_path ()
271- self .input_socket = make_zmq_socket (ctx , input_path ,
272- zmq .constants .PUSH )
271+ self .index = index
272+ self .identity = index .to_bytes (length = 2 , byteorder = "little" )
273273 try :
274274 # Start EngineCore in background process.
275275 self .proc_handle = BackgroundProcHandle (
@@ -291,14 +291,9 @@ def __init__(
291291 # Ensure socket is closed if process fails to start.
292292 self .close ()
293293
294- def send_multipart (self , msg_parts : Sequence ):
295- return self .input_socket .send_multipart (msg_parts , copy = False )
296-
297294 def close (self ):
298295 if proc_handle := getattr (self , "proc_handle" , None ):
299296 proc_handle .shutdown ()
300- if socket := getattr (self , "input_socket" , None ):
301- socket .close (linger = 0 )
302297
303298
304299@dataclass
@@ -309,6 +304,7 @@ class BackgroundResources:
309304 ctx : Union [zmq .Context ]
310305 core_engines : list [CoreEngine ] = field (default_factory = list )
311306 output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
307+ input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
312308 shutdown_path : Optional [str ] = None
313309
314310 def __call__ (self ):
@@ -321,6 +317,8 @@ def __call__(self):
321317 # aren't explicitly closed first.
322318 if self .output_socket is not None :
323319 self .output_socket .close (linger = 0 )
320+ if self .input_socket is not None :
321+ self .input_socket .close (linger = 0 )
324322 if self .shutdown_path is not None :
325323 # We must ensure that the sync output socket is
326324 # closed cleanly in its own thread.
@@ -387,21 +385,51 @@ def sigusr1_handler(signum, frame):
387385
388386 # Paths and sockets for IPC.
389387 self .output_path = get_open_zmq_ipc_path ()
388+ input_path = get_open_zmq_ipc_path ()
389+ self .input_socket = make_zmq_socket (self .ctx ,
390+ input_path ,
391+ zmq .ROUTER ,
392+ bind = True )
393+ self .resources .input_socket = self .input_socket
390394
391395 new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
392- vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
393- index , local_dp_rank )
396+ vllm_config , executor_class , log_stats , input_path , self .
397+ output_path , index , local_dp_rank )
394398
395399 # Start engine core process(es).
396400 self ._init_core_engines (vllm_config , new_core_engine ,
397401 self .resources .core_engines )
398402
399403 # Wait for engine core process(es) to start.
400- for engine in self .resources .core_engines :
401- engine .proc_handle .wait_for_startup ()
404+ self ._wait_for_engine_startup ()
402405
403406 self .utility_results : dict [int , AnyFuture ] = {}
404407
408+ def _wait_for_engine_startup (self ):
409+ # Get a sync handle to the socket which can be sync or async.
410+ sync_input_socket = zmq .Socket .shadow (self .input_socket )
411+
412+ # Wait for engine core process(es) to send ready messages.
413+ identities = set (eng .index for eng in self .resources .core_engines )
414+ while identities :
415+ while not sync_input_socket .poll (timeout = STARTUP_POLL_PERIOD_MS ):
416+ logger .info ("Waiting for %d core engine proc(s) to start: %s" ,
417+ len (identities ), identities )
418+ eng_id_bytes , msg = sync_input_socket .recv_multipart ()
419+ eng_id = int .from_bytes (eng_id_bytes , byteorder = "little" )
420+ if eng_id not in identities :
421+ raise RuntimeError (f"Unexpected or duplicate engine: { eng_id } " )
422+ if msg != b'READY' :
423+ raise RuntimeError (f"Engine { eng_id } failed: { msg .decode ()} " )
424+ logger .info ("Core engine process %d ready." , eng_id )
425+ identities .discard (eng_id )
426+
427+ # Double check that the process are running.
428+ for engine in self .resources .core_engines :
429+ proc = engine .proc_handle .proc
430+ if proc .exitcode is not None :
431+ raise RuntimeError (f"Engine proc { proc .name } not running" )
432+
405433 def _init_core_engines (
406434 self ,
407435 vllm_config : VllmConfig ,
@@ -494,9 +522,10 @@ def get_output(self) -> EngineCoreOutputs:
494522 return self .outputs_queue .get ()
495523
496524 def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
497- # (RequestType, SerializedRequest)
498- msg = (request_type .value , self .encoder .encode (request ))
499- self .core_engine .send_multipart (msg )
525+ # (Identity, RequestType, SerializedRequest)
526+ msg = (self .core_engine .identity , request_type .value ,
527+ self .encoder .encode (request ))
528+ self .input_socket .send_multipart (msg , copy = False )
500529
501530 def call_utility (self , method : str , * args ) -> Any :
502531 call_id = uuid .uuid1 ().int >> 64
@@ -625,30 +654,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
625654 assert self .outputs_queue is not None
626655 return await self .outputs_queue .get ()
627656
628- async def _send_input (self , request_type : EngineCoreRequestType ,
629- request : Any ) -> None :
630- await self .core_engine .send_multipart (
631- (request_type .value , self .encoder .encode (request )))
657+ def _send_input (self ,
658+ request_type : EngineCoreRequestType ,
659+ request : Any ,
660+ engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
661+ if engine is None :
662+ engine = self .core_engine
632663
633- self ._ensure_output_queue_task ()
664+ message = (request_type .value , self .encoder .encode (request ))
665+ return self ._send_input_message (message , engine )
666+
667+ def _send_input_message (self , message : tuple [bytes , bytes ],
668+ engine : CoreEngine ) -> Awaitable [None ]:
669+ message = (engine .identity , ) + message # type: ignore[assignment]
670+ return self .input_socket .send_multipart (message , copy = False )
634671
635672 async def call_utility_async (self , method : str , * args ) -> Any :
636673 return await self ._call_utility_async (method ,
637674 * args ,
638675 engine = self .core_engine )
639676
640- async def _call_utility_async (
641- self ,
642- method : str ,
643- * args ,
644- engine : CoreEngine ,
645- ) -> Any :
677+ async def _call_utility_async (self , method : str , * args ,
678+ engine : CoreEngine ) -> Any :
646679 call_id = uuid .uuid1 ().int >> 64
647680 future = asyncio .get_running_loop ().create_future ()
648681 self .utility_results [call_id ] = future
649682 message = (EngineCoreRequestType .UTILITY .value ,
650683 self .encoder .encode ((call_id , method , args )))
651- await engine . send_multipart (message )
684+ await self . _send_input_message (message , engine )
652685 self ._ensure_output_queue_task ()
653686 return await future
654687
@@ -657,6 +690,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
657690 # tokenized.
658691 request .prompt = None
659692 await self ._send_input (EngineCoreRequestType .ADD , request )
693+ self ._ensure_output_queue_task ()
660694
661695 async def abort_requests_async (self , request_ids : list [str ]) -> None :
662696 if len (request_ids ) > 0 :
@@ -761,15 +795,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
761795 self .reqs_in_flight [request .request_id ] = chosen_engine
762796 chosen_engine .num_reqs_in_flight += 1
763797 if self .num_engines_running >= len (self .core_engines ):
764- await chosen_engine . send_multipart (msg )
798+ await self . _send_input_message (msg , chosen_engine )
765799 else :
766800 # Send request to chosen engine and dp start loop
767801 # control message to all other engines.
768802 self .num_engines_running += len (self .core_engines )
769803 await asyncio .gather (* [
770- engine . send_multipart ( msg if engine is
771- chosen_engine else self .start_dp_msg )
772- for engine in self .core_engines
804+ self . _send_input_message (
805+ msg if engine is chosen_engine else self .start_dp_msg ,
806+ engine ) for engine in self .core_engines
773807 ])
774808
775809 self ._ensure_output_queue_task ()
@@ -794,7 +828,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
794828 # sure to start the other engines:
795829 self .num_engines_running = len (self .core_engines )
796830 coros = [
797- engine . send_multipart (self .start_dp_msg )
831+ self . _send_input_message (self .start_dp_msg , engine )
798832 for engine in self .core_engines
799833 if not engine .num_reqs_in_flight
800834 ]
@@ -820,5 +854,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
820854
821855 async def _abort_requests (self , request_ids : list [str ],
822856 engine : CoreEngine ) -> None :
823- await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
824- self . encoder . encode ( request_ids )) )
857+ await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
858+ engine )
0 commit comments