Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/v1/test_async_llm_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for _ in range(10):
if core_client.num_engines_running == 0:
if not core_client.engines_running:
break
await asyncio.sleep(0.5)

assert core_client.num_engines_running == 0
assert not core_client.engines_running
assert not core_client.reqs_in_flight
15 changes: 12 additions & 3 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class EngineCoreRequest(
arrival_time: float
lora_request: Optional[LoRARequest]

# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0


class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
Expand Down Expand Up @@ -139,8 +144,12 @@ class EngineCoreOutputs(
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None

# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
# In DP case, used to signal that the current wave of requests
# has finished and the engines are paused.
wave_complete: Optional[int] = None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None

def __post_init__(self):
if self.timestamp == 0.0:
Expand All @@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD = b'\x00'
ABORT = b'\x01'
START_DP = b'\x02'
START_DP_WAVE = b'\x02'
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'
63 changes: 42 additions & 21 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def __init__(

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.global_unfinished_reqs = False
self.engines_running = False

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
Expand Down Expand Up @@ -414,19 +414,15 @@ def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""

waited = False
while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
while not self.engines_running and not (self.scheduler.has_requests()):
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)

if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
logger.debug("EngineCore loop active.")

# Handle any more client requests.
while not self.input_queue.empty():
Expand All @@ -450,10 +446,6 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
output = UtilityOutput(call_id)
Expand Down Expand Up @@ -552,9 +544,6 @@ def process_output_socket(self, output_path: str, engine_index: int):
socket.send_multipart(buffers, copy=False)


ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)


class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
Expand Down Expand Up @@ -591,7 +580,9 @@ def __init__(
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))

self.local_dp_rank = local_dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0

# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
Expand All @@ -606,6 +597,31 @@ def shutdown(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

def add_request(self, request: EngineCoreRequest):
# Compare the wave number carried by the request with this engine's own
# wave counter before handing the request to the parent add_request.
# NOTE(review): indentation is not preserved in this view, so the nesting
# of the branches below (in particular which `if` the `elif` pairs with)
# cannot be read from the text - confirm against the real file.
if request.current_wave != self.current_wave:
if request.current_wave > self.current_wave:
# The request carries a higher wave number than ours: adopt it.
self.current_wave = request.current_wave
elif not self.engines_running:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
# put_nowait enqueues the notification without blocking this call.
self.output_queue.put_nowait(
EngineCoreOutputs(start_wave=self.current_wave))

super().add_request(request)

def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
# Intercept START_DP_WAVE control messages; every other request type is
# passed unchanged to the parent class handler.
# NOTE(review): indentation is not preserved in this view, so it cannot be
# told from here whether the engines_running check is nested inside the
# wave comparison - confirm against the real file.
if request_type == EngineCoreRequestType.START_DP_WAVE:
# For this message type the payload is the wave number itself.
new_wave: int = request
if new_wave >= self.current_wave:
# The guard above means the local wave counter never moves backwards.
self.current_wave = new_wave
if not self.engines_running:
logger.debug("EngineCore starting idle loop for wave %d.",
new_wave)
# Mark the engine loop as running.
self.engines_running = True
else:
super()._handle_client_request(request_type, request)

def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""

Expand All @@ -632,7 +648,7 @@ def run_busy_loop(self):
# up-to-date state is returned in the engine outputs.
self._process_engine_step()

if not self.global_unfinished_reqs:
if not self.engines_running:
# All engines are idle.
continue

Expand All @@ -641,18 +657,23 @@ def run_busy_loop(self):
self.execute_dummy_batch()

# 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
self.engines_running = self._has_global_unfinished_reqs(
local_unfinished_reqs)

if not self.global_unfinished_reqs:
# Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
if not self.engines_running:
if self.local_dp_rank == 0:
# Notify client that we are pausing the loop.
logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
self.output_queue.put_nowait(
EngineCoreOutputs(wave_complete=self.current_wave))
self.current_wave += 1

def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

# Optimization - only perform finish-sync all-reduce every 16 steps.
# Optimization - only perform finish-sync all-reduce every 24 steps.
self.counter += 1
if self.counter != 16:
if self.counter != 24:
return True
self.counter = 0

Expand Down
69 changes: 38 additions & 31 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):

self.num_engines_running = 0
self.current_wave = 0
self.engines_running = False
self.reqs_in_flight: dict[str, CoreEngine] = {}

super().__init__(vllm_config, executor_class, log_stats)

# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
*self.encoder.encode(None))

assert len(self.core_engines) > 1

def _init_core_engines(
Expand Down Expand Up @@ -829,23 +826,23 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None

msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
request.current_wave = self.current_wave

chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await self._send_input_message(msg, chosen_engine)
else:

to_await = self._send_input(EngineCoreRequestType.ADD, request,
chosen_engine)
if not self.engines_running:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
self._send_input_message(
msg if engine is chosen_engine else self.start_dp_msg,
engine) for engine in self.core_engines
])
self.engines_running = True
to_await = asyncio.gather(
to_await, # type: ignore[assignment]
*self._start_wave_coros(exclude_index=chosen_engine.index))

await to_await

self._ensure_output_queue_task()

Expand All @@ -860,21 +857,31 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1

if outputs.engine_paused:
assert self.num_engines_running >= 1
self.num_engines_running -= 1
if not self.num_engines_running and self.reqs_in_flight:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
self._send_input_message(self.start_dp_msg, engine)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
if coros:
await asyncio.gather(*coros)
if outputs.wave_complete is not None:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if self.current_wave <= outputs.wave_complete:
self.current_wave = outputs.wave_complete + 1
self.engines_running = False

elif outputs.start_wave is not None and (
outputs.start_wave > self.current_wave or
(outputs.start_wave == self.current_wave
and not self.engines_running)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self.current_wave = outputs.start_wave
self.engines_running = True
await asyncio.gather(*self._start_wave_coros(
exclude_index=outputs.engine_index))

def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
    """Build START_DP_WAVE sends for every core engine except one.

    Args:
        exclude_index: ``index`` of the engine that must not receive the
            start-wave message.

    Returns:
        The results of ``self._send_input`` for each remaining engine, each
        carrying ``self.current_wave``. They are returned un-awaited.
    """
    logger.debug("Sending start DP wave %d.", self.current_wave)

    wave_coros: list[Awaitable[None]] = []
    for core_engine in self.core_engines:
        # Leave out the engine identified by the caller.
        if core_engine.index == exclude_index:
            continue
        wave_coros.append(
            self._send_input(EngineCoreRequestType.START_DP_WAVE,
                             self.current_wave, core_engine))
    return wave_coros

async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids:
Expand Down