Merged
65 commits
9ca44ce
[V1] AsyncLLM data parallel WIP
njhill Feb 26, 2025
3f51611
Handle pausing loop
njhill Feb 27, 2025
d8c591e
More single-node updates
njhill Feb 27, 2025
65e225d
some cleanup
njhill Feb 27, 2025
5ce57b6
fix up utility methods
njhill Feb 27, 2025
a3f1102
revert config check
njhill Feb 27, 2025
a66fb01
fixes
njhill Feb 27, 2025
67672c2
cleanup
njhill Feb 27, 2025
cf52fbf
fixes
njhill Feb 27, 2025
a4ec81b
reconcile with LLMEngine DP in decoupled engine case
njhill Feb 27, 2025
292aa00
minor simplification
njhill Feb 27, 2025
4b62ffd
rework
njhill Feb 28, 2025
407c72e
class refactor
njhill Mar 1, 2025
31bf7ea
fix
njhill Mar 1, 2025
fde51ce
adjust core engine init
njhill Mar 1, 2025
d5a3e68
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 3, 2025
6d89a1b
fix new typing
njhill Mar 3, 2025
448abd9
fix :facepalm:
njhill Mar 3, 2025
a1e513e
bind socket first
njhill Mar 3, 2025
50cf64c
do you have to let it linger
njhill Mar 3, 2025
f365998
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 3, 2025
b2571f0
add comments
njhill Mar 4, 2025
32c6f24
aggregate stats
njhill Mar 4, 2025
9c30cd7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 4, 2025
672d07e
Fix test
njhill Mar 4, 2025
dea382b
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 5, 2025
d24a626
fix and minor cleanup
njhill Mar 5, 2025
cd03c80
Add CI test
njhill Mar 6, 2025
f1004b7
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
d3298fa
Some simplification and fixes
njhill Mar 6, 2025
74dde48
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
5fe1b75
address @markmc's stats suggestion
njhill Mar 6, 2025
648659f
address @tms's arg comment
njhill Mar 6, 2025
119d1ec
fix utility method breakage
njhill Mar 6, 2025
55328ee
rename AsyncMPClient output_processor to output_handler
njhill Mar 6, 2025
4f5330e
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 6, 2025
48770ec
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 7, 2025
d229f4d
Fix
njhill Mar 7, 2025
2f91cc4
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 15, 2025
518047a
Remove redundant logic related to removed stats aggregation
njhill Mar 13, 2025
cb2b099
Fixes
njhill Mar 15, 2025
ff1137a
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 16, 2025
61f4fcb
fix issue from main merge
njhill Mar 16, 2025
44874c2
remove leftover unused field
njhill Mar 17, 2025
66fc582
Fix offline DP compatibility
njhill Mar 17, 2025
7764466
Add timeout to data_parallel.py
njhill Mar 17, 2025
51e8bf0
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 17, 2025
f692c12
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 19, 2025
47b5e1c
Enable less-frequent all-reduce optimization
njhill Mar 20, 2025
f226139
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
af47920
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
693c521
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 20, 2025
6e131e3
clean distributed shutdown
njhill Mar 20, 2025
d9ac856
address misc loose-ends
njhill Mar 20, 2025
3abbdef
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 21, 2025
b18417e
further tweaks
njhill Mar 21, 2025
56b2b78
Merge remote-tracking branch 'refs/remotes/origin/main' into multi-en…
njhill Mar 25, 2025
05ab310
Additional debug
njhill Mar 25, 2025
5295c34
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
4f897b8
Address review comments on tests
njhill Mar 27, 2025
62f32ed
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
771ccf1
Fix env var fallback
njhill Mar 27, 2025
05a0e83
Fix test supports_v1 check
njhill Mar 27, 2025
bc41b13
Fix yapf :facepalm:
njhill Mar 27, 2025
ccecb42
Merge remote-tracking branch 'origin/main' into multi-engine
njhill Mar 27, 2025
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -135,12 +135,14 @@ steps:
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
commands:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
@@ -514,7 +516,10 @@ steps:
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
22 changes: 15 additions & 7 deletions examples/offline_inference/data_parallel.py
@@ -28,6 +28,7 @@
--master-port=13345
"""
import os
from time import sleep

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
@@ -36,14 +37,13 @@
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# set devices for each dp_rank
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(i)
for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
GPUs_per_dp_rank))

# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
# engine processes.

# Sample prompts.
prompts = [
@@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")

# Give engines time to pause their processing loops before exiting.
sleep(1)


if __name__ == "__main__":
import argparse
@@ -152,8 +155,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join()
if proc.exitcode:
proc.join(timeout=300)
if proc.exitcode is None:
print(f"Killing process {proc.pid} that "
f"didn't stop within 5 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:
exit_code = proc.exitcode

exit(exit_code)
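The join logic above replaces an unconditional proc.join() with a bounded join followed by a kill. A minimal standalone sketch of that cleanup pattern, with a placeholder worker standing in for the per-DP-rank engine work:

```python
# Sketch of the bounded-join cleanup pattern used above; the worker body
# and the 300s timeout are illustrative placeholders, not part of the PR.
from multiprocessing import Process
from time import sleep


def worker(rank: int) -> None:
    # Stand-in for the per-DP-rank engine work.
    sleep(0.1)
    print(f"rank {rank} done")


if __name__ == "__main__":
    procs = [Process(target=worker, args=(rank, )) for rank in range(2)]
    for proc in procs:
        proc.start()

    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)
        if proc.exitcode is None:
            # Still running after the timeout: force-kill and flag failure.
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    raise SystemExit(exit_code)
```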
8 changes: 4 additions & 4 deletions tests/v1/engine/test_engine_core_client.py
@@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,

core_client: SyncMPClient = client

result = core_client._call_utility("echo", "testarg")
result = core_client.call_utility("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
core_client._call_utility("echo", None, "help!")
core_client.call_utility("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"

@@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):

core_client: AsyncMPClient = client

result = await core_client._call_utility_async("echo", "testarg")
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"

with pytest.raises(Exception) as e_info:
await core_client._call_utility_async("echo", None, "help!")
await core_client.call_utility_async("echo", None, "help!")

assert str(e_info.value) == "Call to echo method failed: help!"
109 changes: 109 additions & 0 deletions tests/v1/test_async_llm_dp.py
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from contextlib import ExitStack
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient

engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
enforce_eager=True,
disable_log_requests=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
)

if not current_platform.supports_v1(engine_args.create_model_config()):
pytest.skip(reason="Requires V1-supporting platform.",
allow_module_level=True)


async def generate(engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)

count = 0
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True,
output_kind=output_kind,
temperature=0,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):

num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA:
count += num_tokens
else:
count = num_tokens

await asyncio.sleep(0.)

return count, request_id


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):

with ExitStack() as after:

prompt = "This is a test of data parallel"

engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
NUM_EXPECTED_TOKENS = 10

request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))

# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
for task in pending:
task.cancel()
for task in done:
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")

assert not engine.output_processor.has_unfinished_requests()

# testing internals here which may break
core_client: DPAsyncMPClient = engine.engine_core
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for _ in range(10):
if core_client.num_engines_running == 0:
break
await asyncio.sleep(0.5)

assert core_client.num_engines_running == 0
assert not core_client.reqs_in_flight
21 changes: 16 additions & 5 deletions vllm/config.py
@@ -40,7 +40,8 @@
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname)
get_cpu_memory, get_open_port, random_uuid,
resolve_obj_by_qualname)

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -1389,6 +1390,8 @@ class ParallelConfig:
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# Local rank of the data parallel group, defaults to global rank.
data_parallel_rank_local: Optional[int] = None
# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
@@ -1493,10 +1496,18 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_size > 1:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
# TODO multi-node
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

self.world_size_across_dp = self.world_size * self.data_parallel_size

if self.distributed_executor_backend == "external_launcher":
12 changes: 12 additions & 0 deletions vllm/distributed/utils.py
@@ -15,6 +15,8 @@
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
_shutdown_backend,
_unregister_process_group,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous

Expand Down Expand Up @@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group(
pg._register_backend(device, backend_type, backend_class)

return pg


def stateless_destroy_torch_distributed_process_group(
pg: ProcessGroup) -> None:
"""
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
"""
_shutdown_backend(pg)
_unregister_process_group(pg.group_name)
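The new destroy helper is the teardown counterpart to stateless_init_torch_distributed_process_group defined just above this hunk. A hedged usage sketch; the init helper's positional argument list (host, port, rank, world_size, backend) is assumed from surrounding code and is not confirmed by this diff:

```python
# Hedged sketch: pairing the stateless init/destroy helpers around a
# short-lived coordination group. The init helper's argument list
# (host, port, rank, world_size, backend) is an assumption, not shown
# in this hunk.
from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)

pg = stateless_init_torch_distributed_process_group(
    "127.0.0.1", 29500, 0, 1, "gloo")
try:
    pass  # e.g. all-reduce DP scheduling state across engine ranks
finally:
    # Shuts down the backend and unregisters the group without touching
    # torch.distributed's default process group.
    stateless_destroy_torch_distributed_process_group(pg)
```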
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
data_parallel_size: int = 1
enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
@@ -442,6 +443,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--data-parallel-size',
'-dp',
type=int,
default=EngineArgs.data_parallel_size,
help='Number of data parallel replicas. '
'MoE layers will be sharded according to the '
'product of the tensor-parallel-size and '
'data-parallel-size.')
parser.add_argument(
'--enable-expert-parallel',
action='store_true',
@@ -1359,6 +1368,7 @@ def create_engine_config(
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
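Programmatically, the new flag corresponds to the data_parallel_size field on the engine args, which the new tests/v1/test_async_llm_dp.py exercises. A minimal sketch of enabling DP from Python (model name reused from that test):

```python
# Minimal sketch: enabling data parallelism programmatically, mirroring
# the new --data-parallel-size / -dp CLI flag added above.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",  # model reused from the new test
    tensor_parallel_size=1,
    data_parallel_size=2,  # two DP engine replicas
    enforce_eager=True,
)
engine = AsyncLLM.from_engine_args(engine_args)
```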
8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -2,6 +2,7 @@

import hashlib
import os
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional

@@ -95,6 +96,7 @@
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
@@ -625,6 +627,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),

# Rank of the process in the data parallel setting.
# Defaults to VLLM_DP_RANK when not set.
"VLLM_DP_RANK_LOCAL":
lambda: int(
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),

# World size of the data parallel setting
"VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
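A small illustrative check of the fallback added above: when only VLLM_DP_RANK is set, VLLM_DP_RANK_LOCAL resolves to the same value (this relies on vllm.envs evaluating variables lazily at attribute access):

```python
# Illustrative check of the VLLM_DP_RANK_LOCAL fallback: with only
# VLLM_DP_RANK set, the local rank mirrors the global rank.
import os

os.environ["VLLM_DP_RANK"] = "3"
os.environ.pop("VLLM_DP_RANK_LOCAL", None)

import vllm.envs as envs

assert envs.VLLM_DP_RANK == 3
assert envs.VLLM_DP_RANK_LOCAL == 3  # falls back to VLLM_DP_RANK
```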
14 changes: 9 additions & 5 deletions vllm/utils.py
@@ -578,7 +578,7 @@ def get_open_port() -> int:
dp_port = envs.VLLM_DP_MASTER_PORT
while True:
port = _get_open_port()
if port >= dp_port and port < dp_port + 10:
if dp_port <= port < dp_port + 10:
continue
return port
return _get_open_port()
@@ -2176,19 +2176,23 @@ def make_zmq_socket(
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")

return socket


@contextlib.contextmanager
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
def zmq_socket_ctx(
path: str,
socket_type: Any,
linger: int = 0,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

ctx = zmq.Context() # type: ignore[attr-defined]
@@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger.debug("Got Keyboard Interrupt.")

finally:
ctx.destroy(linger=0)
ctx.destroy(linger=linger)


def is_in_ray_actor():
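The socket changes above flip the PULL side to bind and the PUSH side to connect, and zmq_socket_ctx now takes a configurable linger. A self-contained pyzmq sketch of that arrangement; the endpoint and payload are illustrative only:

```python
# Illustrative pyzmq sketch: the PULL (receiving) socket binds the
# endpoint, PUSH (sending) sockets connect to it, and the context is
# destroyed with an explicit linger. Endpoint and payload are placeholders.
import zmq

path = "tcp://127.0.0.1:5555"
ctx = zmq.Context()
try:
    pull = ctx.socket(zmq.PULL)
    pull.setsockopt(zmq.RCVHWM, 0)
    pull.bind(path)      # receiver owns the endpoint

    push = ctx.socket(zmq.PUSH)
    push.setsockopt(zmq.SNDHWM, 0)
    push.connect(path)   # senders attach to it

    push.send(b"hello")
    print(pull.recv())   # b'hello'
finally:
    # linger=0 drops any unsent messages immediately; a positive value
    # gives them time to flush before the context tears down.
    ctx.destroy(linger=0)
```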