Skip to content

Commit 26eaecc

Browse files
committed
[DP][ray] Support different VLLM_RAY_DP_PACK_STRATEGY
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
1 parent b5ee1e3 commit 26eaecc

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

vllm/envs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
VLLM_DP_MASTER_PORT: int = 0
124124
VLLM_MOE_DP_CHUNK_SIZE: int = 256
125125
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
126+
VLLM_RAY_DP_PACK_STRATEGY: str = None
126127
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
127128
VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
128129
VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -913,6 +914,17 @@ def get_vllm_port() -> Optional[int]:
913914
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
914915
lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
915916

917+
# Strategy used to pack data-parallel ranks onto nodes when using the Ray backend.
918+
# Available options:
919+
# - "fill":
920+
# for DP master node, allocate exactly data-parallel-size-local DP ranks,
921+
# for non-master nodes, allocate as many DP ranks as can fit;
922+
# - "strict":
923+
# allocate exactly data-parallel-size-local DP ranks to each picked node;
924+
# This environment variable is ignored if data-parallel-backend is not Ray.
925+
"VLLM_RAY_DP_PACK_STRATEGY":
926+
lambda: os.getenv("VLLM_RAY_DP_PACK_STRATEGY", "fill"),
927+
916928
# Whether to use S3 path for model loading in CI via RunAI Streamer
917929
"VLLM_CI_USE_S3":
918930
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",

vllm/v1/engine/utils.py

Lines changed: 70 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import msgspec
1616
import zmq
1717

18+
from vllm import envs
1819
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
1920
from vllm.logger import init_logger
2021
from vllm.platforms import current_platform
@@ -319,8 +320,8 @@ def create_dp_placement_groups(
319320
logger.info("Creating placement groups for data parallel")
320321
dp_master_ip = \
321322
vllm_config.parallel_config.data_parallel_master_ip
322-
num_pg_to_create = vllm_config.parallel_config.data_parallel_size
323-
local_engine_count = \
323+
dp_size = vllm_config.parallel_config.data_parallel_size
324+
dp_size_local = \
324325
vllm_config.parallel_config.data_parallel_size_local
325326

326327
available_resources = available_resources_per_node()
@@ -334,50 +335,80 @@ def create_dp_placement_groups(
334335
"No nodes with resources found in Ray cluster.")
335336
assert dp_master_ip_key in nodes[0], (
336337
"The DP master node (ip: %s) is missing or dead", dp_master_ip)
338+
339+
if envs.VLLM_RAY_DP_PACK_STRATEGY == "strict":
340+
logger.info(
341+
"Using strict local size packing strategy based "
342+
"on VLLM_RAY_DP_PACK_STRATEGY (%s)",
343+
envs.VLLM_RAY_DP_PACK_STRATEGY)
344+
strict_local_size = True
345+
elif (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
346+
or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"):
347+
logger.info(
348+
"Using strict local size packing strategy based "
349+
"on VLLM_ALL2ALL_BACKEND (%s)", envs.VLLM_ALL2ALL_BACKEND)
350+
strict_local_size = True
351+
else:
352+
logger.info(
353+
"Using fill packing strategy based "
354+
"on VLLM_RAY_DP_PACK_STRATEGY (%s)",
355+
envs.VLLM_RAY_DP_PACK_STRATEGY)
356+
strict_local_size = False
357+
337358
for node_resources in nodes:
338359
if "GPU" not in node_resources:
339360
continue
340361
# For now, each DP rank can only be assigned to one node
341362
# TODO(rui): support allocating a single DP rank
342363
# to multiple nodes
343-
available_engine_count = int(node_resources["GPU"]) // world_size
344-
if dp_master_ip_key in node_resources:
345-
assert available_engine_count >= local_engine_count, (
346-
"Not enough resources to allocate DP ranks "
347-
f"on DP master node {dp_master_ip}")
348-
for i in range(local_engine_count):
349-
bundles = [{
350-
"GPU": 1.0,
351-
"node:" + dp_master_ip: 0.001
352-
}] * world_size + [{
353-
"CPU": 1.0
354-
}]
355-
pg = ray.util.placement_group(
356-
name=f"dp_rank_{len(placement_groups)}",
357-
strategy="STRICT_PACK",
358-
bundles=bundles,
359-
)
360-
placement_groups.append(pg)
361-
local_dp_ranks.append(i)
364+
node_ip_keys = [
365+
key for key in node_resources if key.startswith('node:')
366+
]
367+
assert len(node_ip_keys) == 1, (
368+
"Zero or multiple node IP keys found in node resources: %s",
369+
node_ip_keys)
370+
node_ip_key = node_ip_keys[0]
371+
node_ip = node_ip_key.split(":")[1]
372+
dp_size_available = int(node_resources["GPU"]) // world_size
373+
if strict_local_size:
374+
if dp_size_available < dp_size_local:
375+
if node_ip == dp_master_ip:
376+
raise ValueError(
377+
"Not enough resources to allocate DP ranks "
378+
f"on DP master node {dp_master_ip}")
379+
else:
380+
logger.info(
381+
"Skipping node %s as %s DP ranks could not fit, "
382+
"possible to fit %s DP ranks", node_ip,
383+
dp_size_local, dp_size_available)
384+
continue
385+
dp_size_to_allocate = dp_size_local
386+
elif node_ip == dp_master_ip:
387+
dp_size_to_allocate = dp_size_local
362388
else:
363-
for i in range(available_engine_count):
364-
if len(placement_groups) == num_pg_to_create:
365-
break
366-
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
367-
pg = ray.util.placement_group(
368-
name=f"dp_rank_{len(placement_groups)}",
369-
strategy="STRICT_PACK",
370-
bundles=bundles,
371-
)
372-
placement_groups.append(pg)
373-
local_dp_ranks.append(i)
374-
if len(placement_groups) < num_pg_to_create:
375-
raise ValueError(
376-
f"Not enough resources to allocate {num_pg_to_create} "
377-
"placement groups, only created "
378-
f"{len(placement_groups)} placement groups. "
379-
"Available resources: "
380-
f"{available_resources}")
389+
dp_size_to_allocate = dp_size_available
390+
391+
for i in range(dp_size_to_allocate):
392+
bundles = [{
393+
"GPU": 1.0,
394+
"node:" + node_ip: 0.001
395+
}] * world_size + [{
396+
"CPU": 1.0
397+
}]
398+
pg = ray.util.placement_group(
399+
name=f"dp_rank_{len(placement_groups)}",
400+
strategy="STRICT_PACK",
401+
bundles=bundles,
402+
)
403+
placement_groups.append(pg)
404+
local_dp_ranks.append(i)
405+
406+
if len(placement_groups) < dp_size:
407+
raise ValueError(f"Not enough resources to allocate {dp_size} "
408+
"placement groups, only created "
409+
f"{len(placement_groups)} placement groups. "
410+
"Available resources: "
411+
f"{available_resources}")
381412
return placement_groups, local_dp_ranks
382413

383414
@staticmethod

0 commit comments

Comments (0)