11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+ import copy
34import multiprocessing
45import time
56import weakref
@@ -65,18 +66,14 @@ def __init__(self, parallel_config: ParallelConfig):
6566
6667 # Assume coordinator is colocated with front-end procs when not in
6768 # either external or hybrid DP LB mode.
69+ local_only = not (external_lb or hybrid_lb )
6870 front_publish_address = get_engine_client_zmq_addr (
69- local_only = not external_lb and not hybrid_lb , host = host )
71+ local_only = local_only , host = host )
7072
7173 local_only_eng = dp_size == parallel_config .data_parallel_size_local
7274 back_publish_address = get_engine_client_zmq_addr (local_only_eng , host )
7375 back_output_address = get_engine_client_zmq_addr (local_only_eng , host )
7476
75- # When in external LB mode, load stats aren't published, only changes
76- # to request wave / running state, so we don't need to rate-limit the
77- # updates to the front-end proc(s).
78- min_stats_update_interval_ms = 0 if external_lb else 100
79-
8077 context = get_mp_context ()
8178 self .proc : multiprocessing .Process = context .Process (
8279 target = DPCoordinatorProc .run_coordinator ,
@@ -86,7 +83,6 @@ def __init__(self, parallel_config: ParallelConfig):
8683 "front_publish_address" : front_publish_address ,
8784 "back_output_address" : back_output_address ,
8885 "back_publish_address" : back_publish_address ,
89- "min_stats_update_interval_ms" : min_stats_update_interval_ms ,
9086 },
9187 daemon = True )
9288 self .proc .start ()
@@ -125,10 +121,6 @@ def __init__(self,
125121
126122 self .stats_update_interval_ms = min_stats_update_interval_ms
127123
128- self .current_wave = 0
129- self .engines_running = False
130- self .stats_changed = False
131-
132124 @staticmethod
133125 def run_coordinator (
134126 engine_count : int ,
@@ -155,6 +147,16 @@ def process_input_socket(self, front_publish_address: str,
155147
156148 decoder = MsgpackDecoder (EngineCoreOutputs )
157149
150+ # For tracking request wave progression.
151+ current_wave = 0
152+ engines_running = False
153+
154+ # For tracking request counts for internal load-balancing.
155+ stats_changed = False
156+ last_stats_step = - 1
157+ last_stats_wave = - 1
158+ last_step_counts : Optional [list [list [int ]]] = None
159+
158160 with make_zmq_socket (
159161 path = front_publish_address , # IPC
160162 ctx = self .ctx ,
@@ -191,21 +193,33 @@ def process_input_socket(self, front_publish_address: str,
191193 while True :
192194 elapsed = int (time .time () * 1000 ) - last_publish_time
193195 # Send at stats_update_interval_ms interval if the stats have
194- # changed, or otherwise every 4 seconds.
196+ # changed, or otherwise every 5 seconds.
195197 wait_for = (self .stats_update_interval_ms
196- if self .stats_changed else 4000 )
197- events = poller .poll (timeout = max (0 , wait_for - elapsed ))
198+ if stats_changed else 5000 )
199+
200+ # Wait at least 50ms to ensure we've received all stats for
201+ # the current step.
202+ min_timeout = 50 if last_step_counts is None else 0
203+
204+ events = poller .poll (timeout = max (min_timeout , wait_for -
205+ elapsed ))
198206 if not events :
199207 # Poller timeout - publish current stats to front-ends.
200- engine_req_counts_list = self ._get_engine_counts ()
201- to_publish = (engine_req_counts_list , self .current_wave ,
202- self .engines_running )
208+ if last_step_counts is not None :
209+ engine_req_counts_list = last_step_counts
210+ last_step_counts = None
211+ else :
212+ engine_req_counts_list = self ._get_engine_counts ()
213+ stats_changed = False
214+
215+ to_publish = (engine_req_counts_list , current_wave ,
216+ engines_running )
203217 publish_front .send (msgspec .msgpack .encode (to_publish ))
204218 last_publish_time = int (time .time () * 1000 )
205- self .stats_changed = False
206219 continue
207220
208221 events = dict (events )
222+ wave_state_changed = False
209223
210224 if publish_front in events :
211225 buffer = publish_front .recv ()
@@ -232,7 +246,7 @@ def process_input_socket(self, front_publish_address: str,
232246 # current_wave
233247 # we note that 0 is the wave number for the new
234248 # engine
235- self . engines_running = False
249+ engines_running = False
236250 logger .info (
237251 "DPCoordinator scaled up from %s to %s "
238252 "engines" , current_count , new_engine_count )
@@ -248,15 +262,15 @@ def process_input_socket(self, front_publish_address: str,
248262 # engines are paused, so that we can wake the other
249263 # engines.
250264 engine_to_exclude , wave = decoded
251- if not self . engines_running :
252- if wave < self . current_wave :
265+ if not engines_running :
266+ if wave < current_wave :
253267 # If the wave number is stale, ensure the message
254268 # is handled by all the engines.
255269 engine_to_exclude = None
256270
257- self . engines_running = True
258- self . stats_changed = True
259- self ._send_start_wave (publish_back , self . current_wave ,
271+ engines_running = True
272+ wave_state_changed = True
273+ self ._send_start_wave (publish_back , current_wave ,
260274 engine_to_exclude )
261275
262276 if output_back in events :
@@ -274,36 +288,56 @@ def process_input_socket(self, front_publish_address: str,
274288 # 1. Updated request load stats - update our local
275289 # state with these.
276290 stats = self .engines [eng_index ].request_counts
291+ stats_step = scheduler_stats .step_counter
292+ stats_wave = scheduler_stats .current_wave
293+ if (stats_wave > last_stats_wave
294+ or stats_wave == last_stats_wave
295+ and stats_step > last_stats_step ):
296+ if stats_changed :
297+ last_step_counts = self ._get_engine_counts (
298+ do_copy = True )
299+ last_stats_step = stats_step
300+ last_stats_wave = stats_wave
301+ elif stats_wave != last_stats_wave or (
302+ stats_step != last_stats_step ):
303+ logger .warning (
304+ "Received stats for out-of-order "
305+ "step (%d, %d) from engine %d (expected "
306+ "> (%d, %d))" , stats_wave , stats_step ,
307+ eng_index , last_stats_wave , last_stats_step )
277308 stats [0 ] = scheduler_stats .num_waiting_reqs
278309 stats [1 ] = scheduler_stats .num_running_reqs
279- self . stats_changed = True
310+ stats_changed = True
280311
281312 if (wave := outputs .wave_complete ) is not None :
282313 # 2. Notification from rank 0 engine that we've
283314 # moved into the global paused state
284315 # (engines_running==False).
285- if self . current_wave <= wave :
316+ if current_wave <= wave :
286317 new_wave = wave + 1
287318 logger .debug ("Moving DP wave from %d to %d." ,
288- self . current_wave , new_wave )
289- self . current_wave = new_wave
290- self . engines_running = False
291- self . stats_changed = True
319+ current_wave , new_wave )
320+ current_wave = new_wave
321+ engines_running = False
322+ wave_state_changed = True
292323 elif (wave := outputs .start_wave ) is not None and (
293- wave > self .current_wave or
294- (wave == self .current_wave
295- and not self .engines_running )):
324+ wave > current_wave or
325+ (wave == current_wave and not engines_running )):
296326 # 3. The engine received request for a non-current wave
297327 # so we must ensure that other engines progress to the
298328 # next wave (race condition handling).
299329 logger .debug (
300330 "Starting wave %d after notification of "
301331 "stale wave request from engine." , wave )
302- self . current_wave = wave
303- self . engines_running = True
304- self . stats_changed = True
332+ current_wave = wave
333+ engines_running = True
334+ wave_state_changed = True
305335 self ._send_start_wave (publish_back , wave , eng_index )
306336
337+ if wave_state_changed :
338+ message = (None , current_wave , engines_running )
339+ publish_front .send (msgspec .msgpack .encode (message ))
340+
307341 @staticmethod
308342 def _send_start_wave (socket : zmq .Socket , wave : int ,
309343 exclude_engine_index : Optional [int ]):
@@ -316,6 +350,8 @@ def _send_start_wave(socket: zmq.Socket, wave: int,
316350 socket .send_multipart (
317351 (EngineCoreRequestType .START_DP_WAVE .value , wave_encoded ))
318352
319- def _get_engine_counts (self ) -> list [list [int ]]:
353+ def _get_engine_counts (self , do_copy = False ) -> list [list [int ]]:
320354 """Return list of [waiting, running] count lists for each engine."""
355+ if do_copy :
356+ return [copy .copy (e .request_counts ) for e in self .engines ]
321357 return [e .request_counts for e in self .engines ]
0 commit comments