Commit bdce64f

[V1] Support DP with Ray (#18779)

1 parent: 9e6f61e
10 files changed: +539 -108 lines

requirements/test.in

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
 vocos # required for minicpmo_26 test
 peft
 pqdm
-ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
+ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
 jiwer # required for audio tests

requirements/test.txt

Lines changed: 50 additions & 0 deletions (transitive dependencies pulled in by the new ray[default] extra)

@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
     # via aiohttp
 aiohttp==3.10.11
     # via
+    #   aiohttp-cors
     #   datasets
     #   fsspec
     #   lm-eval
+    #   ray
+aiohttp-cors==0.8.1
+    # via ray
 aiosignal==1.3.1
     # via
     #   aiohttp
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
     # via pqdm
 buildkite-test-collector==0.1.9
     # via -r requirements/test.in
+cachetools==5.5.2
+    # via google-auth
 certifi==2024.8.30
     # via
     #   httpcore
@@ -81,6 +87,8 @@ colorama==0.4.6
     #   sacrebleu
     #   schemathesis
     #   tqdm-multiprocess
+colorful==0.5.6
+    # via ray
 contourpy==1.3.0
     # via matplotlib
 cramjam==2.9.0
@@ -108,6 +116,8 @@ dill==0.3.8
     #   evaluate
     #   lm-eval
     #   multiprocess
+distlib==0.3.9
+    # via virtualenv
 dnspython==2.7.0
     # via email-validator
 docopt==0.6.2
@@ -143,6 +153,7 @@ filelock==3.16.1
     #   ray
     #   torch
     #   transformers
+    #   virtualenv
 fonttools==4.54.1
     # via matplotlib
 fqdn==1.5.1
@@ -165,8 +176,16 @@ genai-perf==0.0.8
     # via -r requirements/test.in
 genson==1.3.0
     # via datamodel-code-generator
+google-api-core==2.24.2
+    # via opencensus
+google-auth==2.40.2
+    # via google-api-core
+googleapis-common-protos==1.70.0
+    # via google-api-core
 graphql-core==3.2.6
     # via hypothesis-graphql
+grpcio==1.71.0
+    # via ray
 h11==0.14.0
     # via httpcore
 harfile==0.3.0
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
     #   torch
 nvidia-nvtx-cu12==12.8.55
     # via torch
+opencensus==0.11.4
+    # via ray
+opencensus-context==0.1.3
+    # via opencensus
 opencv-python-headless==4.11.0.86
     # via
     #   -r requirements/test.in
@@ -445,6 +468,7 @@ platformdirs==4.3.6
     # via
     #   black
     #   pooch
+    #   virtualenv
 plotly==5.24.1
     # via genai-perf
 pluggy==1.5.0
@@ -457,10 +481,17 @@ portalocker==2.10.1
     # via sacrebleu
 pqdm==0.2.0
     # via -r requirements/test.in
+prometheus-client==0.22.0
+    # via ray
 propcache==0.2.0
     # via yarl
+proto-plus==1.26.1
+    # via google-api-core
 protobuf==5.28.3
     # via
+    #   google-api-core
+    #   googleapis-common-protos
+    #   proto-plus
     #   ray
     #   tensorizer
 psutil==6.1.0
@@ -470,10 +501,18 @@ psutil==6.1.0
     #   tensorizer
 py==1.11.0
     # via pytest-forked
+py-spy==0.4.0
+    # via ray
 pyarrow==18.0.0
     # via
     #   datasets
     #   genai-perf
+pyasn1==0.6.1
+    # via
+    #   pyasn1-modules
+    #   rsa
+pyasn1-modules==0.4.2
+    # via google-auth
 pybind11==2.13.6
     # via lm-eval
 pycparser==2.22
@@ -486,6 +525,7 @@ pydantic==2.11.5
     #   datamodel-code-generator
     #   mistral-common
     #   mteb
+    #   ray
 pydantic-core==2.33.2
     # via pydantic
 pygments==2.18.0
@@ -573,6 +613,7 @@ requests==2.32.3
     #   buildkite-test-collector
     #   datasets
     #   evaluate
+    #   google-api-core
     #   huggingface-hub
     #   lm-eval
     #   mistral-common
@@ -601,6 +642,8 @@ rpds-py==0.20.1
     # via
     #   jsonschema
     #   referencing
+rsa==4.9.1
+    # via google-auth
 runai-model-streamer==0.11.0
     # via -r requirements/test.in
 runai-model-streamer-s3==0.11.0
@@ -648,9 +691,12 @@ shellingham==1.5.4
 six==1.16.0
     # via
     #   junit-xml
+    #   opencensus
     #   python-dateutil
     #   rfc3339-validator
     #   rouge-score
+smart-open==7.1.0
+    # via ray
 sniffio==1.3.1
     # via
     #   anyio
@@ -801,6 +847,8 @@ urllib3==2.2.3
     #   tritonclient
 vector-quantize-pytorch==1.21.2
     # via -r requirements/test.in
+virtualenv==20.31.2
+    # via ray
 vocos==0.1.0
     # via -r requirements/test.in
 webcolors==24.11.1
@@ -809,6 +857,8 @@ werkzeug==3.1.3
     # via schemathesis
 word2number==1.1
     # via lm-eval
+wrapt==1.17.2
+    # via smart-open
 xxhash==3.5.0
     # via
     #   datasets

tests/v1/test_async_llm_dp.py

Lines changed: 10 additions & 3 deletions

@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,


 @pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+    "output_kind",
+    [
+        RequestOutputKind.DELTA,
+        RequestOutputKind.FINAL_ONLY,
+    ],
+)
+@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind):
+async def test_load(output_kind: RequestOutputKind,
+                    data_parallel_backend: str):

     with ExitStack() as after:

         prompt = "This is a test of data parallel"

+        engine_args.data_parallel_backend = data_parallel_backend
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
             asyncio.create_task(
                 generate(engine, request_id, prompt, output_kind,
                          NUM_EXPECTED_TOKENS)))
-
     # Confirm that we got all the EXPECTED tokens from the requests.
     done, pending = await asyncio.wait(tasks,
                                        return_when=asyncio.FIRST_EXCEPTION)
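
Outside the test harness, the same opt-in looks like the following. A minimal sketch, assuming a reachable Ray runtime and an illustrative model choice; the test module's own engine_args definition is not shown in this hunk:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",    # illustrative; the test uses its own args
    data_parallel_size=2,
    data_parallel_backend="ray",  # the new knob; "mp" remains the default
)


async def main():
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        params = SamplingParams(max_tokens=16,
                                output_kind=RequestOutputKind.FINAL_ONLY)
        # generate() is an async generator; with FINAL_ONLY it yields a
        # single, complete RequestOutput.
        async for out in engine.generate("This is a test of data parallel",
                                         sampling_params=params,
                                         request_id="req-0"):
            print(out.outputs[0].text)
    finally:
        engine.shutdown()


asyncio.run(main())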

vllm/config.py

Lines changed: 6 additions & 0 deletions

@@ -1742,6 +1742,8 @@ class ParallelConfig:
     """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
+    data_parallel_backend: str = "mp"
+    """Backend to use for data parallel, either "mp" or "ray"."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None

@@ -1911,6 +1913,10 @@ def __post_init__(self) -> None:
                     "please install Ray with `pip install "
                     "ray`.") from ray_utils.ray_import_err
             backend = "ray"
+        elif self.data_parallel_backend == "ray":
+            logger.info("Using ray distributed inference because "
+                        "data_parallel_backend is ray")
+            backend = "ray"
         elif ray_found:
             if self.placement_group:
                 backend = "ray"
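
Note the placement: the new elif sits after the existing forced-Ray branch and before the ray_found autodetection, so an explicit Ray request wins over the heuristics that follow it. A tiny hedged sketch of the new field itself:

from vllm.config import ParallelConfig

# New dataclass field; per the docstring above, accepted values are
# "mp" (the default) and "ray".
assert ParallelConfig().data_parallel_backend == "mp"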

vllm/engine/arg_utils.py

Lines changed: 25 additions & 4 deletions

@@ -39,7 +39,7 @@
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_ray_actor)
+                        GiB_bytes, get_ip, is_in_ray_actor)

 # yapf: enable

@@ -292,6 +292,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers

@@ -624,6 +625,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    default='mp',
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])

@@ -1059,23 +1066,37 @@ def create_engine_config(

         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
-        data_parallel_address = self.data_parallel_address if (
-            self.data_parallel_address
-            is not None) else ParallelConfig.data_parallel_master_ip
+        if self.data_parallel_address is None:
+            if self.data_parallel_backend == "ray":
+                host_ip = get_ip()
+                logger.info(
+                    "Using host IP %s as ray-based data parallel address",
+                    host_ip)
+                data_parallel_address = host_ip
+            else:
+                assert self.data_parallel_backend == "mp", (
+                    "data_parallel_backend can only be ray or mp, got %s",
+                    self.data_parallel_backend)
+                data_parallel_address = ParallelConfig.data_parallel_master_ip
+        else:
+            data_parallel_address = self.data_parallel_address

         # This port is only used when there are remote data parallel engines,
         # otherwise the local IPC transport is used.
         data_parallel_rpc_port = self.data_parallel_rpc_port if (
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port

+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
vllm/entrypoints/cli/serve.py

Lines changed: 30 additions & 5 deletions

@@ -27,7 +27,8 @@
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
 from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
-                           EngineZmqAddresses, get_engine_client_zmq_addr,
+                           CoreEngineActorManager, EngineZmqAddresses,
+                           get_engine_client_zmq_addr,
                            wait_for_completion_or_failure,
                            wait_for_engine_startup)

@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
         logger.info("Started DP Coordinator process (PID: %d)",
                     coordinator.proc.pid)

+    if parallel_config.data_parallel_backend == "ray":
+        logger.info("Starting ray-based data parallel backend")
+
+        engine_actor_manager = CoreEngineActorManager(
+            vllm_config=vllm_config,
+            addresses=addresses,
+            executor_class=Executor.get_class(vllm_config),
+            log_stats=not engine_args.disable_log_stats,
+        )
+        # Start API servers using the manager
+        api_server_manager = APIServerProcessManager(
+            target_server_fn=run_api_server_worker_proc,
+            listen_address=listen_address,
+            sock=sock,
+            args=args,
+            num_servers=num_api_servers,
+            input_addresses=input_addresses,
+            output_addresses=output_addresses,
+            stats_update_address=stats_update_address)
+
+        wait_for_completion_or_failure(api_server_manager=api_server_manager,
+                                       engine_manager=engine_actor_manager,
+                                       coordinator=coordinator)
+        return
+
     handshake_address = get_engine_client_zmq_addr(
         local_only, host, parallel_config.data_parallel_rpc_port)

@@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
     )

     # Wait for API servers
-    wait_for_completion_or_failure(
-        api_server_manager=api_server_manager,
-        local_engine_manager=local_engine_manager,
-        coordinator=coordinator)
+    wait_for_completion_or_failure(api_server_manager=api_server_manager,
+                                   engine_manager=local_engine_manager,
+                                   coordinator=coordinator)


 def run_api_server_worker_proc(listen_address,
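
Control-flow note: when data_parallel_backend is "ray", run_multi_api_server hands the engine cores to CoreEngineActorManager (Ray actors) and returns early, so the ZMQ handshake and local engine-process startup below this block run only for the "mp" backend. The keyword rename from local_engine_manager to engine_manager lets wait_for_completion_or_failure supervise either manager type with the same wait loop.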

vllm/v1/engine/async_llm.py

Lines changed: 9 additions & 4 deletions

@@ -27,7 +27,8 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
+                                        RayDPClient)
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)

@@ -119,9 +120,13 @@ def __init__(
             log_stats=self.log_stats)

         # EngineCore (starts the engine in background process).
-        core_client_class = AsyncMPClient if (
-            vllm_config.parallel_config.data_parallel_size
-            == 1) else DPAsyncMPClient
+        core_client_class: type[AsyncMPClient]
+        if vllm_config.parallel_config.data_parallel_size == 1:
+            core_client_class = AsyncMPClient
+        elif vllm_config.parallel_config.data_parallel_backend == "ray":
+            core_client_class = RayDPClient
+        else:
+            core_client_class = DPAsyncMPClient

         self.engine_core = core_client_class(
             vllm_config=vllm_config,
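
Net effect: AsyncLLM now has a three-way dispatch for its engine-core client: AsyncMPClient when data_parallel_size == 1, the new RayDPClient when the Ray backend is selected, and DPAsyncMPClient otherwise. Callers of AsyncLLM.from_engine_args need no changes; the choice is read entirely from the parallel config.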
