3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/api_server.py

@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(

         from vllm.v1.engine.async_llm import AsyncLLM
         async_llm: Optional[AsyncLLM] = None
+        client_count = client_config.pop(
+            "client_count") if client_config else 1
         client_index = client_config.pop(
             "client_index") if client_config else 0
         try:
@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
                 enable_log_requests=engine_args.enable_log_requests,
                 disable_log_stats=engine_args.disable_log_stats,
                 client_addresses=client_config,
+                client_count=client_count,
                 client_index=client_index)

             # Don't keep the dummy data in memory
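Note on the pattern above: the same client_config dict is forwarded as client_addresses in the second hunk, so the two coordination fields must be popped out first, leaving only the ZMQ addresses. A minimal standalone sketch of that split; the address key names are illustrative assumptions, not taken from this diff:

from typing import Any, Optional

def split_client_config(client_config: Optional[dict[str, Any]]):
    # Pop coordination fields, defaulting to a single front-end client;
    # whatever remains in the dict is passed on as client_addresses.
    client_count = client_config.pop("client_count") if client_config else 1
    client_index = client_config.pop("client_index") if client_config else 0
    return client_count, client_index, client_config

# Hypothetical config for API server 0 of 2.
cfg = {"input_address": "ipc:///tmp/in.sock",
       "output_address": "ipc:///tmp/out.sock",
       "client_count": 2, "client_index": 0}
count, index, addresses = split_client_config(cfg)
assert (count, index) == (2, 0) and "client_count" not in addresses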
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py

@@ -57,6 +57,7 @@ def __init__(
         start_engine_loop: bool = True,
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
     ) -> None:
         """
@@ -120,6 +121,7 @@ def __init__(
             executor_class=executor_class,
             log_stats=self.log_stats,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )

@@ -156,6 +158,7 @@ def from_vllm_config(
         enable_log_requests: bool = False,
         disable_log_stats: bool = False,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
         disable_log_requests: bool = True,  # Deprecated, will be removed
     ) -> "AsyncLLM":
@@ -176,6 +179,7 @@ def from_vllm_config(
             log_stats=not disable_log_stats,
             usage_context=usage_context,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )
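Both __init__ and from_vllm_config simply thread client_count through to the underlying core engine client, alongside the existing client_index. A hedged usage sketch for one of two API-server processes; the counts and the pre-built vllm_config, usage_context, and client_addresses values are illustrative assumptions, not from this diff:

# Sketch: front-end process 0 of 2 attaching to a shared engine.
async_llm = AsyncLLM.from_vllm_config(
    vllm_config=vllm_config,            # assumed to be built by the caller
    usage_context=usage_context,
    client_addresses=client_addresses,  # ZMQ addresses from the launcher
    client_count=2,   # total number of front-end clients
    client_index=0,   # this process's position among them
)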
110 changes: 73 additions & 37 deletions vllm/v1/engine/coordinator.py

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import multiprocessing
 import time
 import weakref
@@ -65,18 +66,14 @@ def __init__(self, parallel_config: ParallelConfig):

         # Assume coordinator is colocated with front-end procs when not in
         # either external or hybrid DP LB mode.
+        local_only = not (external_lb or hybrid_lb)
         front_publish_address = get_engine_client_zmq_addr(
-            local_only=not external_lb and not hybrid_lb, host=host)
+            local_only=local_only, host=host)

         local_only_eng = dp_size == parallel_config.data_parallel_size_local
         back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
         back_output_address = get_engine_client_zmq_addr(local_only_eng, host)

-        # When in external LB mode, load stats aren't published, only changes
-        # to request wave / running state, so we don't need to rate-limit the
-        # updates to the front-end proc(s).
-        min_stats_update_interval_ms = 0 if external_lb else 100
-
         context = get_mp_context()
         self.proc: multiprocessing.Process = context.Process(
             target=DPCoordinatorProc.run_coordinator,
@@ -86,7 +83,6 @@ def __init__(self, parallel_config: ParallelConfig):
                 "front_publish_address": front_publish_address,
                 "back_output_address": back_output_address,
                 "back_publish_address": back_publish_address,
-                "min_stats_update_interval_ms": min_stats_update_interval_ms,
             },
             daemon=True)
         self.proc.start()
@@ -125,10 +121,6 @@ def __init__(self,

         self.stats_update_interval_ms = min_stats_update_interval_ms

-        self.current_wave = 0
-        self.engines_running = False
-        self.stats_changed = False
-
     @staticmethod
     def run_coordinator(
             engine_count: int,
@@ -155,6 +147,16 @@ def process_input_socket(self, front_publish_address: str,

         decoder = MsgpackDecoder(EngineCoreOutputs)

+        # For tracking request wave progression.
+        current_wave = 0
+        engines_running = False
+
+        # For tracking request counts for internal load-balancing.
+        stats_changed = False
+        last_stats_step = -1
+        last_stats_wave = -1
+        last_step_counts: Optional[list[list[int]]] = None
+
         with make_zmq_socket(
                 path=front_publish_address,  # IPC
                 ctx=self.ctx,
@@ -179,21 +181,33 @@ def process_input_socket(self, front_publish_address: str,
             while True:
                 elapsed = int(time.time() * 1000) - last_publish_time
                 # Send at stats_update_interval_ms interval if the stats have
-                # changed, or otherwise every 4 seconds.
+                # changed, or otherwise every 5 seconds.
                 wait_for = (self.stats_update_interval_ms
-                            if self.stats_changed else 4000)
-                events = poller.poll(timeout=max(0, wait_for - elapsed))
+                            if stats_changed else 5000)
+
+                # Wait at least 50ms to ensure we've received all stats for
+                # the current step.
+                min_timeout = 50 if last_step_counts is None else 0
+
+                events = poller.poll(timeout=max(min_timeout, wait_for -
+                                                 elapsed))
                 if not events:
                     # Poller timeout - publish current stats to front-ends.
-                    engine_req_counts_list = self._get_engine_counts()
-                    to_publish = (engine_req_counts_list, self.current_wave,
-                                  self.engines_running)
+                    if last_step_counts is not None:
+                        engine_req_counts_list = last_step_counts
+                        last_step_counts = None
+                    else:
+                        engine_req_counts_list = self._get_engine_counts()
+                        stats_changed = False
+
+                    to_publish = (engine_req_counts_list, current_wave,
+                                  engines_running)
                     publish_front.send(msgspec.msgpack.encode(to_publish))
                     last_publish_time = int(time.time() * 1000)
-                    self.stats_changed = False
                     continue

                 events = dict(events)
+                wave_state_changed = False

                 if publish_front in events:
                     buffer = publish_front.recv()
@@ -220,7 +234,7 @@ def process_input_socket(self, front_publish_address: str,
                             # current_wave
                             # we note that 0 is the wave number for the new
                             # engine
-                            self.engines_running = False
+                            engines_running = False
                             logger.info(
                                 "DPCoordinator scaled up from %s to %s "
                                 "engines", current_count, new_engine_count)
@@ -236,15 +250,15 @@ def process_input_socket(self, front_publish_address: str,
                         # engines are paused, so that we can wake the other
                         # engines.
                         engine_to_exclude, wave = decoded
-                        if not self.engines_running:
-                            if wave < self.current_wave:
+                        if not engines_running:
+                            if wave < current_wave:
                                 # If the wave number is stale, ensure the message
                                 # is handled by all the engines.
                                 engine_to_exclude = None

-                            self.engines_running = True
-                            self.stats_changed = True
-                            self._send_start_wave(publish_back, self.current_wave,
+                            engines_running = True
+                            wave_state_changed = True
+                            self._send_start_wave(publish_back, current_wave,
                                                   engine_to_exclude)

                 if output_back in events:
@@ -262,36 +276,56 @@ def process_input_socket(self, front_publish_address: str,
                         # 1. Updated request load stats - update our local
                         # state with these.
                         stats = self.engines[eng_index].request_counts
+                        stats_step = scheduler_stats.step_counter
+                        stats_wave = scheduler_stats.current_wave
+                        if (stats_wave > last_stats_wave
+                                or stats_wave == last_stats_wave
+                                and stats_step > last_stats_step):
+                            if stats_changed:
+                                last_step_counts = self._get_engine_counts(
+                                    do_copy=True)
+                            last_stats_step = stats_step
+                            last_stats_wave = stats_wave
+                        elif stats_wave != last_stats_wave or (
+                                stats_step != last_stats_step):
+                            logger.warning(
+                                "Received stats for out-of-order "
+                                "step (%d, %d) from engine %d (expected "
+                                "> (%d, %d))", stats_wave, stats_step,
+                                eng_index, last_stats_wave, last_stats_step)
                         stats[0] = scheduler_stats.num_waiting_reqs
                         stats[1] = scheduler_stats.num_running_reqs
-                        self.stats_changed = True
+                        stats_changed = True

                     if (wave := outputs.wave_complete) is not None:
                         # 2. Notification from rank 0 engine that we've
                         # moved into the global paused state
                         # (engines_running==False).
-                        if self.current_wave <= wave:
+                        if current_wave <= wave:
                             new_wave = wave + 1
                             logger.debug("Moving DP wave from %d to %d.",
-                                         self.current_wave, new_wave)
-                            self.current_wave = new_wave
-                            self.engines_running = False
-                            self.stats_changed = True
+                                         current_wave, new_wave)
+                            current_wave = new_wave
+                            engines_running = False
+                            wave_state_changed = True
                     elif (wave := outputs.start_wave) is not None and (
-                            wave > self.current_wave or
-                            (wave == self.current_wave
-                             and not self.engines_running)):
+                            wave > current_wave or
+                            (wave == current_wave and not engines_running)):
                         # 3. The engine received request for a non-current wave
                         # so we must ensure that other engines progress to the
                         # next wave (race condition handling).
                         logger.debug(
                             "Starting wave %d after notification of "
                             "stale wave request from engine.", wave)
-                        self.current_wave = wave
-                        self.engines_running = True
-                        self.stats_changed = True
+                        current_wave = wave
+                        engines_running = True
+                        wave_state_changed = True
                         self._send_start_wave(publish_back, wave, eng_index)

+                if wave_state_changed:
+                    message = (None, current_wave, engines_running)
+                    publish_front.send(msgspec.msgpack.encode(message))
+
     @staticmethod
     def _send_start_wave(socket: zmq.Socket, wave: int,
                          exclude_engine_index: Optional[int]):
@@ -304,6 +338,8 @@ def _send_start_wave(socket: zmq.Socket, wave: int,
         socket.send_multipart(
             (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

-    def _get_engine_counts(self) -> list[list[int]]:
+    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
         """Return list of [waiting, running] count lists for each engine."""
+        if do_copy:
+            return [copy.copy(e.request_counts) for e in self.engines]
         return [e.request_counts for e in self.engines]
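Two details of the coordinator change are worth spelling out. The freshness test is a lexicographic comparison on (wave, step): stats are accepted only from a newer wave, or from the same wave at a later step, and anything else triggers the out-of-order warning. And _get_engine_counts(do_copy=True) exists because each request_counts list is mutated in place as new stats arrive, so the snapshot held in last_step_counts must be copied to stay consistent until it is published. A standalone sketch of the predicate (the function name is mine; the logic mirrors the diff):

def is_newer(stats_wave: int, stats_step: int,
             last_wave: int, last_step: int) -> bool:
    # Equivalent to the coordinator's condition: stats_wave > last_wave,
    # or the same wave with a strictly later step counter.
    return (stats_wave, stats_step) > (last_wave, last_step)

assert is_newer(1, 0, 0, 99)     # a newer wave wins regardless of step
assert is_newer(1, 5, 1, 4)      # same wave, later step
assert not is_newer(1, 4, 1, 4)  # a duplicate step is not newer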
13 changes: 8 additions & 5 deletions vllm/v1/engine/core.py

@@ -907,7 +907,7 @@ def __init__(
     ):
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
-        self.counter = 0
+        self.step_counter = 0
         self.current_wave = 0
         self.last_counts = (0, 0)

@@ -978,7 +978,9 @@ def _maybe_publish_request_counts(self):
         counts = self.scheduler.get_request_counts()
         if counts != self.last_counts:
             self.last_counts = counts
-            stats = SchedulerStats(*counts)
+            stats = SchedulerStats(*counts,
+                                   step_counter=self.step_counter,
+                                   current_wave=self.current_wave)
             self.output_queue.put_nowait(
                 (-1, EngineCoreOutputs(scheduler_stats=stats)))

@@ -1020,15 +1022,16 @@ def run_busy_loop(self):
                     self.output_queue.put_nowait(
                         (client_index,
                          EngineCoreOutputs(wave_complete=self.current_wave)))
+                # Increment wave count and reset step counter.
                 self.current_wave += 1
+                self.step_counter = 0

     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

         # Optimization - only perform finish-sync all-reduce every 32 steps.
-        self.counter += 1
-        if self.counter != 32:
+        self.step_counter += 1
+        if self.step_counter % 32 != 0:
             return True
-        self.counter = 0

         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
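The rename from counter to step_counter comes with a behavioral fix: the old increment-then-reset pattern only worked if the sync itself was the sole place the counter was cleared, while the modulo form stays correct even though run_busy_loop now zeroes step_counter at every wave boundary, keeping all DP ranks aligned on which steps participate in the finish-sync all-reduce. A minimal sketch of the gate, assuming N=32 as in the diff:

class StepGate:
    """Every-N-steps gate mirroring the skip logic in
    _has_global_unfinished_reqs."""

    def __init__(self, n: int = 32):
        self.n = n
        self.step_counter = 0

    def skip_sync(self) -> bool:
        # Matches `self.step_counter % 32 != 0`: the all-reduce runs
        # only on steps N, 2N, 3N, ...; resetting step_counter to 0 at
        # a wave boundary simply restarts the cycle without desyncing.
        self.step_counter += 1
        return self.step_counter % self.n != 0

gate = StepGate()
assert sum(not gate.skip_sync() for _ in range(64)) == 2  # steps 32 and 64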