@@ -325,7 +325,7 @@ def __init__(
325325
326326 self .step_fn = (self .step if self .batch_queue is None else
327327 self .step_with_batch_queue )
328- self .global_unfinished_reqs = False
328+ self .engines_running = False
329329
330330 # Background Threads and Queues for IO. These enable us to
331331 # overlap ZMQ socket IO with GPU since they release the GIL,
@@ -410,19 +410,15 @@ def _process_input_queue(self):
410410 """Exits when an engine step needs to be performed."""
411411
412412 waited = False
413- while not self .global_unfinished_reqs and not (
414- self .scheduler .has_requests ()):
413+ while not self .engines_running and not (self .scheduler .has_requests ()):
415414 if logger .isEnabledFor (DEBUG ) and self .input_queue .empty ():
416415 logger .debug ("EngineCore waiting for work." )
417416 waited = True
418417 req = self .input_queue .get ()
419418 self ._handle_client_request (* req )
420419
421420 if waited :
422- logger .debug (
423- "EngineCore loop active - local unfinished: %s, finished: %s." ,
424- self .scheduler .has_unfinished_requests (),
425- self .scheduler .has_finished_requests ())
421+ logger .debug ("EngineCore loop active." )
426422
427423 # Handle any more client requests.
428424 while not self .input_queue .empty ():
@@ -446,10 +442,6 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
446442 self .add_request (request )
447443 elif request_type == EngineCoreRequestType .ABORT :
448444 self .abort_requests (request )
449- elif request_type == EngineCoreRequestType .START_DP :
450- if not self .global_unfinished_reqs :
451- logger .debug ("EngineCore starting idle loop." )
452- self .global_unfinished_reqs = True
453445 elif request_type == EngineCoreRequestType .UTILITY :
454446 call_id , method_name , args = request
455447 output = UtilityOutput (call_id )
@@ -548,9 +540,6 @@ def process_output_socket(self, output_path: str, engine_index: int):
548540 socket .send_multipart (buffers , copy = False )
549541
550542
551- ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs (engine_paused = True )
552-
553-
554543class DPEngineCoreProc (EngineCoreProc ):
555544 """ZMQ-wrapper for running EngineCore in background process
556545 in a data parallel context."""
@@ -587,7 +576,9 @@ def __init__(
587576 for i in range (local_dp_rank * tp_size , (local_dp_rank + 1 ) *
588577 tp_size ))
589578
579+ self .local_dp_rank = local_dp_rank
590580 self .dp_group = vllm_config .parallel_config .stateless_init_dp_group ()
581+ self .current_wave = 0
591582
592583 # Initialize the engine after setting up environment.
593584 super ().__init__ (input_path , output_path , vllm_config , executor_class ,
@@ -602,6 +593,31 @@ def shutdown(self):
602593 if dp_group := getattr (self , "dp_group" , None ):
603594 stateless_destroy_torch_distributed_process_group (dp_group )
604595
596+ def add_request (self , request : EngineCoreRequest ):
597+ if request .current_wave != self .current_wave :
598+ if request .current_wave > self .current_wave :
599+ self .current_wave = request .current_wave
600+ elif not self .engines_running :
601+ # Request received for an already-completed wave, notify
602+ # front-end that we need to start the next one.
603+ self .output_queue .put_nowait (
604+ EngineCoreOutputs (start_wave = self .current_wave ))
605+
606+ super ().add_request (request )
607+
608+ def _handle_client_request (self , request_type : EngineCoreRequestType ,
609+ request : Any ) -> None :
610+ if request_type == EngineCoreRequestType .START_DP_WAVE :
611+ new_wave : int = request
612+ if new_wave >= self .current_wave :
613+ self .current_wave = new_wave
614+ if not self .engines_running :
615+ logger .debug ("EngineCore starting idle loop for wave %d." ,
616+ new_wave )
617+ self .engines_running = True
618+ else :
619+ super ()._handle_client_request (request_type , request )
620+
605621 def run_busy_loop (self ):
606622 """Core busy loop of the EngineCore for data parallel case."""
607623
@@ -628,7 +644,7 @@ def run_busy_loop(self):
628644 # up-to-date state is returned in the engine outputs.
629645 self ._process_engine_step ()
630646
631- if not self .global_unfinished_reqs :
647+ if not self .engines_running :
632648 # All engines are idle.
633649 continue
634650
@@ -637,18 +653,23 @@ def run_busy_loop(self):
637653 self .execute_dummy_batch ()
638654
639655 # 3) All-reduce operation to determine global unfinished reqs.
640- self .global_unfinished_reqs = self ._has_global_unfinished_reqs (
656+ self .engines_running = self ._has_global_unfinished_reqs (
641657 local_unfinished_reqs )
642658
643- if not self .global_unfinished_reqs :
644- # Notify client that we are pausing the loop.
645- self .output_queue .put_nowait (ENGINE_PAUSED_OUTPUTS )
659+ if not self .engines_running :
660+ if self .local_dp_rank == 0 :
661+ # Notify client that we are pausing the loop.
662+ logger .debug ("Wave %d finished, pausing engine loop." ,
663+ self .current_wave )
664+ self .output_queue .put_nowait (
665+ EngineCoreOutputs (wave_complete = self .current_wave ))
666+ self .current_wave += 1
646667
647668 def _has_global_unfinished_reqs (self , local_unfinished : bool ) -> bool :
648669
649- # Optimization - only perform finish-sync all-reduce every 16 steps.
670+ # Optimization - only perform finish-sync all-reduce every 24 steps.
650671 self .counter += 1
651- if self .counter != 16 :
672+ if self .counter != 24 :
652673 return True
653674 self .counter = 0
654675
0 commit comments