@@ -4,8 +4,9 @@
 import signal
 import threading
 import time
+from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type
 
 import psutil
 import zmq
@@ -18,11 +19,12 @@
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
@@ -66,9 +68,22 @@ def __init__(
             log_stats=self.log_stats,
         )
 
+        # Setup MM Input Mapper.
         self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
+        # Setup batch queue for pipeline parallelism.
+        # Batch queue for scheduled batches. This enables us to asynchronously
+        # schedule and execute batches, and is required by pipeline parallelism
+        # to eliminate pipeline bubbles.
+        self.batch_queue_size = self.model_executor.max_concurrent_batches
+        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
+                                                     SchedulerOutput]]] = None
+        if self.batch_queue_size > 1:
+            logger.info("Batch queue is enabled with size %d",
+                        self.batch_queue_size)
+            self.batch_queue = queue.Queue(self.batch_queue_size)
+
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
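# A minimal sketch of the sizing idea above, assuming (for illustration only)
# an executor that allows one in-flight batch per pipeline stage; the bounded
# queue then holds at most that many (future, scheduler_output) pairs.
# `ToyExecutor` and `num_pipeline_stages` are hypothetical names, not vLLM APIs.
import queue


class ToyExecutor:

    def __init__(self, num_pipeline_stages: int = 1):
        self.num_pipeline_stages = num_pipeline_stages

    @property
    def max_concurrent_batches(self) -> int:
        # One batch in flight per stage keeps every stage busy (no bubbles).
        return self.num_pipeline_stages


executor = ToyExecutor(num_pipeline_stages=2)
batch_queue: queue.Queue = queue.Queue(maxsize=executor.max_concurrent_batches)
print(batch_queue.maxsize)  # 2: at most two scheduled batches in flight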
@@ -135,7 +150,55 @@ def step(self) -> EngineCoreOutputs:
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)
+            scheduler_output, output)  # type: ignore
+        return engine_core_outputs
+
+    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
+        """Schedule and execute batches with the batch queue.
+        Note that if there is nothing to output in this step, None is returned.
+
+        The execution flow is as follows:
+        1. Try to schedule a new batch if there are unscheduled requests
+        and the job queue is not full. If a new batch is scheduled, return
+        None immediately. In other words, we won't check and return model
+        outputs before the batch queue is full.
+        2. If there is no new scheduled batch, meaning that the batch queue
+        is full or no other requests can be scheduled, we block until the
+        first batch in the job queue is finished.
+        3. Update the scheduler from the output.
+        """
+        assert self.batch_queue is not None
+
+        engine_core_outputs = None
+        scheduler_output = None
+        # If there are unscheduled requests and the job queue is not full,
+        # schedule a new batch. Note that this is not blocking.
+        if (self.scheduler.get_num_unscheduled_requests() > 0
+                and not self.batch_queue.full()):
+            scheduler_output = self.scheduler.schedule()
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                future = self.model_executor.execute_model(scheduler_output)
+                self.batch_queue.put_nowait(
+                    (future, scheduler_output))  # type: ignore
+
+        # If all requests are scheduled or the job queue is full,
+        # block until the first batch in the job queue is finished.
+        if (scheduler_output is None
+                or scheduler_output.total_num_scheduled_tokens == 0):
+            try:
+                future, scheduler_output = self.batch_queue.get(
+                    timeout=POLLING_TIMEOUT_S)
+                # Blocking until the first result is available.
+                model_output = future.result()
+                self.batch_queue.task_done()
+                engine_core_outputs = self.scheduler.update_from_output(
+                    scheduler_output, model_output)
+            except queue.Empty:
+                # If the queue is empty (timeout at .get), return an empty
+                # EngineCoreOutputs for logging.
+                engine_core_outputs = EngineCoreOutputs(
+                    outputs=[], scheduler_stats=self.scheduler.make_stats())
+
         return engine_core_outputs
 
     def shutdown(self):
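# A self-contained sketch of the stepping pattern introduced above, assuming
# toy stand-ins (ToyScheduler, run_model) rather than the real vLLM scheduler
# and executor: submit batches asynchronously into a bounded FIFO queue while
# there is room, otherwise block on the oldest in-flight batch and report it.
import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional, Tuple


class ToyScheduler:

    def __init__(self, num_requests: int):
        self.pending = list(range(num_requests))
        self.finished: List[int] = []

    def has_unscheduled(self) -> bool:
        return bool(self.pending)

    def schedule(self) -> int:
        return self.pending.pop(0)

    def update_from_output(self, batch_id: int, output: str) -> str:
        self.finished.append(batch_id)
        return output


def run_model(batch_id: int) -> str:
    # Stand-in for execute_model(); returns the "model output" for the batch.
    return f"output-for-batch-{batch_id}"


def step_with_batch_queue(scheduler, pool, batch_queue) -> Optional[str]:
    # 1) If there is unscheduled work and the queue has room, submit a new
    #    batch without blocking and report nothing for this step.
    if scheduler.has_unscheduled() and not batch_queue.full():
        batch_id = scheduler.schedule()
        batch_queue.put_nowait((pool.submit(run_model, batch_id), batch_id))
        return None
    # 2) Otherwise block on the oldest in-flight batch and report its output.
    if not batch_queue.empty():
        future, batch_id = batch_queue.get()
        return scheduler.update_from_output(batch_id, future.result())
    return None


scheduler = ToyScheduler(num_requests=4)
batch_queue: "queue.Queue[Tuple[Future, int]]" = queue.Queue(maxsize=2)
with ThreadPoolExecutor(max_workers=2) as pool:
    while scheduler.has_unscheduled() or not batch_queue.empty():
        out = step_with_batch_queue(scheduler, pool, batch_queue)
        if out is not None:
            print(out)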
@@ -226,6 +289,9 @@ def signal_handler(signum, frame):
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""
 
+        step_fn = (self.step
+                   if self.batch_queue is None else self.step_with_batch_queue)
+
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
@@ -249,10 +315,11 @@ def run_busy_loop(self):
                 self._handle_client_request(*req)
 
             # 3) Step the engine core.
-            outputs = self.step()
+            outputs = step_fn()
 
-            # 5) Put EngineCoreOutputs into the output queue.
-            self.output_queue.put_nowait(outputs)
+            # 4) Put EngineCoreOutputs into the output queue.
+            if outputs is not None:
+                self.output_queue.put_nowait(outputs)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
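# A small sketch of the dispatch pattern in run_busy_loop, assuming a toy
# ToyCore class (illustrative, not the vLLM API): the step function is chosen
# once up front, and a None result (possible only on the batch-queue path,
# when a step merely enqueues work) is skipped rather than forwarded.
import queue
from typing import Optional


class ToyCore:

    def __init__(self, use_batch_queue: bool):
        self.batch_queue = queue.Queue(2) if use_batch_queue else None
        self.output_queue: "queue.Queue[str]" = queue.Queue()
        self._tick = 0

    def step(self) -> str:
        self._tick += 1
        return f"outputs@{self._tick}"

    def step_with_batch_queue(self) -> Optional[str]:
        # Odd ticks only enqueue work, so they have nothing to report.
        self._tick += 1
        return None if self._tick % 2 else f"outputs@{self._tick}"

    def run_busy_loop(self, iterations: int) -> None:
        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)
        for _ in range(iterations):
            outputs = step_fn()
            # Only forward real outputs; None means this step produced nothing.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)


core = ToyCore(use_batch_queue=True)
core.run_busy_loop(iterations=4)  # output_queue ends up with 2 items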