From 8177e2f02f12944dfec5d8c5dce6801c91320627 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 25 Jul 2025 14:48:25 +0100 Subject: [PATCH 1/4] [BugFix] Improve internal DP load balancing Signed-off-by: Nick Hill --- vllm/v1/engine/coordinator.py | 98 ++++++++++++++++++++++------------- vllm/v1/engine/core.py | 10 ++-- vllm/v1/engine/core_client.py | 30 ++++++----- vllm/v1/metrics/stats.py | 2 + 4 files changed, 85 insertions(+), 55 deletions(-) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index fc45eea3a73c..5ebee31ebaed 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import multiprocessing import time import weakref @@ -66,18 +67,14 @@ def __init__(self, parallel_config: ParallelConfig): # Assume coordinator is colocated with front-end procs when not in # either external or hybrid DP LB mode. + local_only = not (external_lb or hybrid_lb) front_publish_address = get_engine_client_zmq_addr( - local_only=not external_lb and not hybrid_lb, host=host) + local_only=local_only, host=host) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host) - # When in external LB mode, load stats aren't published, only changes - # to request wave / running state, so we don't need to rate-limit the - # updates to the front-end proc(s). - min_stats_update_interval_ms = 0 if external_lb else 100 - context = get_mp_context() self.proc: multiprocessing.Process = context.Process( target=DPCoordinatorProc.run_coordinator, @@ -87,7 +84,6 @@ def __init__(self, parallel_config: ParallelConfig): "front_publish_address": front_publish_address, "back_output_address": back_output_address, "back_publish_address": back_publish_address, - "min_stats_update_interval_ms": min_stats_update_interval_ms, }, daemon=True) self.proc.start() @@ -126,10 +122,6 @@ def __init__(self, self.stats_update_interval_ms = min_stats_update_interval_ms - self.current_wave = 0 - self.engines_running = False - self.stats_changed = False - @staticmethod def run_coordinator( engine_count: int, @@ -156,6 +148,13 @@ def process_input_socket(self, front_publish_address: str, decoder = MsgpackDecoder(EngineCoreOutputs) + current_wave = 0 + engines_running = False + + stats_changed = False + last_stats_step = -1 + last_step_counts: Optional[list[list[int]]] = None + with make_zmq_socket( path=front_publish_address, # IPC ctx=self.ctx, @@ -180,21 +179,33 @@ def process_input_socket(self, front_publish_address: str, while True: elapsed = int(time.time() * 1000) - last_publish_time # Send at stats_update_interval_ms interval if the stats have - # changed, or otherwise every 4 seconds. + # changed, or otherwise every 5 seconds. wait_for = (self.stats_update_interval_ms - if self.stats_changed else 4000) - events = poller.poll(timeout=max(0, wait_for - elapsed)) + if stats_changed else 5000) + + # Wait at least 50ms to ensure we've received all stats for + # the current step. + min_timeout = 50 if last_step_counts is None else 0 + + events = poller.poll(timeout=max(min_timeout, wait_for - + elapsed)) if not events: # Poller timeout - publish current stats to front-ends. 
- engine_req_counts_list = self._get_engine_counts() - to_publish = (engine_req_counts_list, self.current_wave, - self.engines_running) + if last_step_counts is not None: + engine_req_counts_list = last_step_counts + last_step_counts = None + else: + engine_req_counts_list = self._get_engine_counts() + stats_changed = False + + to_publish = (engine_req_counts_list, current_wave, + engines_running) publish_front.send(msgspec.msgpack.encode(to_publish)) last_publish_time = int(time.time() * 1000) - self.stats_changed = False continue events = dict(events) + wave_state_changed = False if publish_front in events: buffer = publish_front.recv() @@ -221,7 +232,7 @@ def process_input_socket(self, front_publish_address: str, # current_wave # we note that 0 is the wave number for the new # engine - self.engines_running = False + engines_running = False logger.info( "DPCoordinator scaled up from %s to %s " "engines", current_count, new_engine_count) @@ -237,15 +248,15 @@ def process_input_socket(self, front_publish_address: str, # engines are paused, so that we can wake the other # engines. engine_to_exclude, wave = decoded - if not self.engines_running: - if wave < self.current_wave: + if not engines_running: + if wave < current_wave: # If the wave number is stale, ensure the message # is handled by all the engines. engine_to_exclude = None - self.engines_running = True - self.stats_changed = True - self._send_start_wave(publish_back, self.current_wave, + engines_running = True + wave_state_changed = True + self._send_start_wave(publish_back, current_wave, engine_to_exclude) if output_back in events: @@ -263,36 +274,47 @@ def process_input_socket(self, front_publish_address: str, # 1. Updated request load stats - update our local # state with these. stats = self.engines[eng_index].request_counts + stats_step = scheduler_stats.step_counter + if stats_changed and stats_step != last_stats_step: + last_step_counts = self._get_engine_counts( + do_copy=True) + elif stats_step < last_stats_step: + logger.warning("Received stats for out-of-order " + "step from engine {eng_index}") stats[0] = scheduler_stats.num_waiting_reqs stats[1] = scheduler_stats.num_running_reqs - self.stats_changed = True + last_stats_step = stats_step + stats_changed = True if (wave := outputs.wave_complete) is not None: # 2. Notification from rank 0 engine that we've # moved into the global paused state # (engines_running==False). - if self.current_wave <= wave: + if current_wave <= wave: new_wave = wave + 1 logger.debug("Moving DP wave from %d to %d.", - self.current_wave, new_wave) - self.current_wave = new_wave - self.engines_running = False - self.stats_changed = True + current_wave, new_wave) + current_wave = new_wave + engines_running = False + wave_state_changed = True elif (wave := outputs.start_wave) is not None and ( - wave > self.current_wave or - (wave == self.current_wave - and not self.engines_running)): + wave > current_wave or + (wave == current_wave and not engines_running)): # 3. The engine received request for a non-current wave # so we must ensure that other engines progress to the # next wave (race condition handling). 
logger.debug( "Starting wave %d after notification of " "stale wave request from engine.", wave) - self.current_wave = wave - self.engines_running = True - self.stats_changed = True + current_wave = wave + engines_running = True + wave_state_changed = True self._send_start_wave(publish_back, wave, eng_index) + if wave_state_changed: + message = (None, current_wave, engines_running) + publish_front.send(msgspec.msgpack.encode(message)) + @staticmethod def _send_start_wave(socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int]): @@ -305,6 +327,8 @@ def _send_start_wave(socket: zmq.Socket, wave: int, socket.send_multipart( (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) - def _get_engine_counts(self) -> list[list[int]]: + def _get_engine_counts(self, do_copy=False) -> list[list[int]]: """Return list of [waiting, running] count lists for each engine.""" + if do_copy: + return [copy.copy(e.request_counts) for e in self.engines] return [e.request_counts for e in self.engines] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 88c511606d7c..a1642a999033 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -874,7 +874,7 @@ def __init__( # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. - self.counter = 0 + self.step_counter = 0 self.current_wave = 0 self.last_counts = (0, 0) @@ -954,7 +954,7 @@ def _maybe_publish_request_counts(self): counts = self.scheduler.get_request_counts() if counts != self.last_counts: self.last_counts = counts - stats = SchedulerStats(*counts) + stats = SchedulerStats(*counts, step_counter=self.step_counter) self.output_queue.put_nowait( (-1, EngineCoreOutputs(scheduler_stats=stats))) @@ -1001,10 +1001,10 @@ def run_busy_loop(self): def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: # Optimization - only perform finish-sync all-reduce every 32 steps. - self.counter += 1 - if self.counter != 32: + self.step_counter += 1 + if self.step_counter != 32: return True - self.counter = 0 + self.step_counter = 0 return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 69ae3690d00e..7f712a0acb81 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -970,7 +970,12 @@ async def run_engine_stats_update_task(): counts, wave, running = msgspec.msgpack.decode(buf) self.current_wave = wave self.engines_running = running - self.lb_engines = counts[count_slice] + if counts is not None: + sliced_counts = counts[count_slice] + self.lb_engines = sliced_counts + #TODO TBD whether to keep this debug log + logger.debug("Received counts: %s (%s)", + sliced_counts, count_slice) resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task()) @@ -1019,27 +1024,26 @@ def __init__(self, def get_core_engine_for_request( self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. + current_counts = self.lb_engines if (eng_index := request.data_parallel_rank) is None: - if not self.lb_engines: + if not current_counts: return self.core_engine # TODO use P2C alg for larger DP sizes - num_engines = len(self.lb_engines) - min_counts = [sys.maxsize, sys.maxsize] + num_engines = len(current_counts) + min_score = sys.maxsize eng_index = 0 for i in range(num_engines): # Start from client_index to help with balancing when engines # are empty. 
idx = (self.client_index + i) % num_engines - counts = self.lb_engines[idx] - if counts < min_counts: - min_counts = counts + waiting, running = current_counts[idx] + score = waiting * 4 + running + if score < min_score: + min_score = score eng_index = idx - # Adjust local counts for better balancing between stats updates - # from the coordinator (which happen every 100ms). - if min_counts[0]: - min_counts[0] += 1 - else: - min_counts[1] += 1 + # Increment local waiting count for better balancing between stats + # updates from the coordinator (which happen every 100ms). + current_counts[eng_index][0] += 1 chosen_engine = self.core_engines[eng_index] # Record which engine is chosen for this request, to handle aborts. diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 1eb10ccb6c49..5e2f9e739280 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -33,6 +33,8 @@ class SchedulerStats: num_running_reqs: int = 0 num_waiting_reqs: int = 0 + step_counter: int = 0 + kv_cache_usage: float = 0.0 prefix_cache_stats: PrefixCacheStats = field( From d7ab219cfb63b17be654d22a2f4a9dcf75fa3ee4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 28 Jul 2025 13:22:46 +0100 Subject: [PATCH 2/4] improvements, particularly for multi-api-server case Signed-off-by: Nick Hill --- vllm/entrypoints/openai/api_server.py | 3 +++ vllm/v1/engine/async_llm.py | 4 ++++ vllm/v1/engine/coordinator.py | 24 +++++++++++++++++------- vllm/v1/engine/core.py | 9 ++++++--- vllm/v1/engine/core_client.py | 24 ++++++++++++++++-------- vllm/v1/metrics/stats.py | 1 + vllm/v1/utils.py | 1 + 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8540d25d4e94..413b5c670822 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -200,6 +200,8 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_count = client_config.pop( + "client_count") if client_config else 1 client_index = client_config.pop( "client_index") if client_config else 0 try: @@ -209,6 +211,7 @@ async def build_async_engine_client_from_engine_args( disable_log_requests=engine_args.disable_log_requests, disable_log_stats=engine_args.disable_log_stats, client_addresses=client_config, + client_count=client_count, client_index=client_index) # Don't keep the dummy data in memory diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 02cb80197fa4..1e232e4099e5 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -56,6 +56,7 @@ def __init__( start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0, ) -> None: """ @@ -119,6 +120,7 @@ def __init__( executor_class=executor_class, log_stats=self.log_stats, client_addresses=client_addresses, + client_count=client_count, client_index=client_index, ) @@ -150,6 +152,7 @@ def from_vllm_config( disable_log_requests: bool = False, disable_log_stats: bool = False, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: @@ -169,6 +172,7 @@ def from_vllm_config( log_stats=not disable_log_stats, usage_context=usage_context, client_addresses=client_addresses, + client_count=client_count, client_index=client_index, ) diff --git 
a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 5ebee31ebaed..2233a5f28a0d 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -153,6 +153,7 @@ def process_input_socket(self, front_publish_address: str, stats_changed = False last_stats_step = -1 + last_stats_wave = -1 last_step_counts: Optional[list[list[int]]] = None with make_zmq_socket( @@ -275,15 +276,24 @@ def process_input_socket(self, front_publish_address: str, # state with these. stats = self.engines[eng_index].request_counts stats_step = scheduler_stats.step_counter - if stats_changed and stats_step != last_stats_step: - last_step_counts = self._get_engine_counts( - do_copy=True) - elif stats_step < last_stats_step: - logger.warning("Received stats for out-of-order " - "step from engine {eng_index}") + stats_wave = scheduler_stats.current_wave + if (stats_wave > last_stats_wave + or stats_wave == last_stats_wave + and stats_step > last_stats_step): + if stats_changed: + last_step_counts = self._get_engine_counts( + do_copy=True) + last_stats_step = stats_step + last_stats_wave = stats_wave + elif stats_wave != last_stats_wave or ( + stats_step != last_stats_step): + logger.warning( + "Received stats for out-of-order " + "step (%d, %d) from engine %d (expected " + "> (%d, %d))", stats_wave, stats_step, + eng_index, last_stats_wave, last_stats_step) stats[0] = scheduler_stats.num_waiting_reqs stats[1] = scheduler_stats.num_running_reqs - last_stats_step = stats_step stats_changed = True if (wave := outputs.wave_complete) is not None: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a1642a999033..1b476b40e54f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -954,7 +954,9 @@ def _maybe_publish_request_counts(self): counts = self.scheduler.get_request_counts() if counts != self.last_counts: self.last_counts = counts - stats = SchedulerStats(*counts, step_counter=self.step_counter) + stats = SchedulerStats(*counts, + step_counter=self.step_counter, + current_wave=self.current_wave) self.output_queue.put_nowait( (-1, EngineCoreOutputs(scheduler_stats=stats))) @@ -996,15 +998,16 @@ def run_busy_loop(self): self.output_queue.put_nowait( (client_index, EngineCoreOutputs(wave_complete=self.current_wave))) + # Increment wave count and reset step counter. self.current_wave += 1 + self.step_counter = 0 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: # Optimization - only perform finish-sync all-reduce every 32 steps. self.step_counter += 1 - if self.step_counter != 32: + if self.step_counter % 32 != 0: return True - self.step_counter = 0 return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 7f712a0acb81..8972700d596d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -84,11 +84,12 @@ def make_async_mp_client( executor_class: type[Executor], log_stats: bool, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0, ) -> "MPClient": parallel_config = vllm_config.parallel_config client_args = (vllm_config, executor_class, log_stats, - client_addresses, client_index) + client_addresses, client_count, client_index) if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_external_lb: # External load balancer - client per DP rank. 
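The coordinator change above replaces PATCH 1's step-only staleness check with a (wave, step) pair, because an engine resets its step counter whenever a new wave begins (see the core.py hunk). As a minimal sketch, the freshness rule amounts to a lexicographic tuple comparison; the standalone function below is illustrative only, with an assumed name, and is not part of the patch:

# Hypothetical sketch of the coordinator's stats-freshness rule.
# A report is fresh when its wave advances, or when the wave matches
# and the step advances, i.e. lexicographic order on (wave, step).

def stats_are_fresh(stats_wave: int, stats_step: int,
                    last_wave: int, last_step: int) -> bool:
    return (stats_wave, stats_step) > (last_wave, last_step)

# Wave 1 step 0 supersedes wave 0 step 57: the step counter resets
# when a new wave starts.
assert stats_are_fresh(1, 0, 0, 57)
# Same wave, later step: fresh.
assert stats_are_fresh(0, 3, 0, 2)
# Same (wave, step), e.g. a second engine reporting in the same step:
# not fresh, but also not flagged by the coordinator's elif warning.
assert not stats_are_fresh(0, 2, 0, 2)
# Strictly older: the case the coordinator logs as out-of-order.
assert not stats_are_fresh(0, 1, 0, 2)
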
@@ -673,6 +674,7 @@ def __init__(self, executor_class: type[Executor], log_stats: bool, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0): super().__init__( asyncio_mode=True, @@ -870,11 +872,12 @@ def __init__(self, executor_class: type[Executor], log_stats: bool, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0): self.current_wave = 0 super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_index) + client_addresses, client_count, client_index) # List of [waiting, running] pair per engine. # Used only by DPLBAsyncMPClient subclass. @@ -973,9 +976,8 @@ async def run_engine_stats_update_task(): if counts is not None: sliced_counts = counts[count_slice] self.lb_engines = sliced_counts - #TODO TBD whether to keep this debug log - logger.debug("Received counts: %s (%s)", - sliced_counts, count_slice) + logger.debug("Received counts: %s (%s)", sliced_counts, + count_slice) resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task()) @@ -1011,16 +1013,22 @@ def __init__(self, executor_class: type[Executor], log_stats: bool, client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, client_index: int = 0): + self.client_count = client_count + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, EngineIdentity] = {} super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_index) + client_addresses, client_count, client_index) assert len(self.core_engines) > 1 + self.eng_start_index = (len(self.core_engines) * + self.client_index) // client_count + def get_core_engine_for_request( self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. @@ -1035,7 +1043,7 @@ def get_core_engine_for_request( for i in range(num_engines): # Start from client_index to help with balancing when engines # are empty. - idx = (self.client_index + i) % num_engines + idx = (self.eng_start_index + i) % num_engines waiting, running = current_counts[idx] score = waiting * 4 + running if score < min_score: @@ -1043,7 +1051,7 @@ def get_core_engine_for_request( eng_index = idx # Increment local waiting count for better balancing between stats # updates from the coordinator (which happen every 100ms). - current_counts[eng_index][0] += 1 + current_counts[eng_index][0] += self.client_count chosen_engine = self.core_engines[eng_index] # Record which engine is chosen for this request, to handle aborts. 
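Pulling the core_client.py pieces together: each front-end scores engines from their waiting and running counts, staggers its scan start by client index so that ties do not send every API server to the same engine, and optimistically bumps the winner's local waiting count until the coordinator's next stats refresh (roughly every 100ms) arrives. A standalone re-implementation under assumed names (illustrative only, not the patch itself):

# Hypothetical sketch of DPLBAsyncMPClient request routing after this
# series. `counts` is the coordinator-published [waiting, running]
# pair per engine.

def choose_engine(counts: list[list[int]], client_index: int,
                  client_count: int) -> int:
    num_engines = len(counts)
    # Stagger each API server's starting offset so tied scores do not
    # all resolve to engine 0 (mirrors eng_start_index above).
    start = (num_engines * client_index) // client_count
    best_index, best_score = 0, float("inf")
    for i in range(num_engines):
        idx = (start + i) % num_engines
        waiting, running = counts[idx]
        # A waiting request is weighted 4x heavier than a running one.
        score = waiting * 4 + running
        if score < best_score:
            best_index, best_score = idx, score
    # Optimistic local update: with client_count front-ends routing
    # similar traffic, assume each one also just placed a request here.
    counts[best_index][0] += client_count
    return best_index

counts = [[0, 2], [1, 0], [0, 0]]  # [waiting, running] for 3 engines
assert choose_engine(counts, client_index=0, client_count=2) == 2
assert counts[2] == [2, 0]  # waiting count bumped by client_count
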
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 5e2f9e739280..d6528c9596ed 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -34,6 +34,7 @@ class SchedulerStats: num_waiting_reqs: int = 0 step_counter: int = 0 + current_wave: int = 0 kv_cache_usage: float = 0.0 diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index bb5a36f38386..40c186470374 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -154,6 +154,7 @@ def __init__( client_config = { "input_address": in_addr, "output_address": out_addr, + "client_count": num_servers, "client_index": i } if stats_update_address is not None: From a53327ad7dc93c9268ca561683a2bfadeb0e32e9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 28 Jul 2025 15:18:23 +0100 Subject: [PATCH 3/4] add some comments Signed-off-by: Nick Hill --- vllm/v1/engine/coordinator.py | 2 ++ vllm/v1/metrics/stats.py | 1 + 2 files changed, 3 insertions(+) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 2233a5f28a0d..c7ffc1eb8161 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -148,9 +148,11 @@ def process_input_socket(self, front_publish_address: str, decoder = MsgpackDecoder(EngineCoreOutputs) + # For tracking request wave progression. current_wave = 0 engines_running = False + # For tracking request counts for internal load-balancing. stats_changed = False last_stats_step = -1 last_stats_wave = -1 diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index d6528c9596ed..9a80460261e0 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -33,6 +33,7 @@ class SchedulerStats: num_running_reqs: int = 0 num_waiting_reqs: int = 0 + # These are used for internal DP load-balancing. step_counter: int = 0 current_wave: int = 0 From 9246ac817705f6eaebb3879455b554b536264eb4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 1 Aug 2025 19:45:07 +0100 Subject: [PATCH 4/4] linting fix Signed-off-by: Nick Hill --- vllm/v1/engine/async_llm.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c61cdefab51b..45f450291ab6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -150,17 +150,17 @@ def __init__( "Use `enable_log_requests` instead."), ) def from_vllm_config( - cls, - vllm_config: VllmConfig, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError(