Commit 1d86f50

[DP][ray] Support different VLLM_RAY_DP_PACK_STRATEGY
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
1 parent 31a4b3e commit 1d86f50
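
This commit introduces the VLLM_RAY_DP_PACK_STRATEGY environment variable to control how data-parallel (DP) ranks are packed onto Ray nodes. A minimal usage sketch, not part of this commit (the serve flags in the comment are vLLM's existing data-parallel options, shown only for illustration):

import os

# Pick the packing strategy before launching a Ray-backed DP deployment;
# the default is "fill".
os.environ["VLLM_RAY_DP_PACK_STRATEGY"] = "strict"

# Then launch as usual, e.g.:
#   vllm serve <model> --data-parallel-size 4 \
#       --data-parallel-size-local 2 --data-parallel-backend ray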

File tree: 2 files changed (+88 −52 lines)

vllm/envs.py

Lines changed: 14 additions & 4 deletions
@@ -133,6 +133,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
+    VLLM_RAY_DP_PACK_STRATEGY: str = "fill"
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -997,10 +998,19 @@ def get_vllm_port() -> Optional[int]:
     # units.
     "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
     # Randomize inputs during dummy runs when using Data Parallel
-    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get(
-        "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0"
-    )
-    == "1",
+    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
+    lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
+
+    # Strategy to pack the data parallel ranks for Ray.
+    # Available options:
+    # - "fill":
+    #   for the DP master node, allocate exactly data-parallel-size-local
+    #   DP ranks; for non-master nodes, allocate as many DP ranks as can fit;
+    # - "strict":
+    #   allocate exactly data-parallel-size-local DP ranks to each picked node.
+    # This environment variable is ignored if data-parallel-backend is not Ray.
+    "VLLM_RAY_DP_PACK_STRATEGY":
+    lambda: os.getenv("VLLM_RAY_DP_PACK_STRATEGY", "fill"),
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
     # Use model_redirect to redirect the model name to a local folder.
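
To make the two strategies concrete: with world_size GPUs needed per DP rank, "fill" places exactly data-parallel-size-local ranks on the master node and as many as fit on every other node, while "strict" places exactly data-parallel-size-local ranks on each node it picks and skips nodes that cannot fit them. A self-contained sketch of the resulting allocations (not vLLM code; GPU counts and IPs are made up):

# Sketch of how the two strategies divide DP ranks across nodes.
world_size = 2                          # GPUs per DP rank (TP * PP)
dp_size_local = 4                       # data-parallel-size-local
dp_master_ip = "10.0.0.1"
nodes = {"10.0.0.1": 8, "10.0.0.2": 6}  # node IP -> available GPUs

for strategy in ("fill", "strict"):
    allocation = {}
    for node_ip, gpus in nodes.items():
        fits = gpus // world_size       # DP ranks this node can host
        if strategy == "strict":
            # Exactly dp_size_local ranks, or skip the node entirely
            # (vLLM raises instead of skipping when the master cannot fit).
            allocation[node_ip] = dp_size_local if fits >= dp_size_local else 0
        else:
            # "fill": master gets exactly dp_size_local; others, whatever fits.
            allocation[node_ip] = (dp_size_local
                                   if node_ip == dp_master_ip else fits)
    print(strategy, allocation)
# fill {'10.0.0.1': 4, '10.0.0.2': 3}
# strict {'10.0.0.1': 4, '10.0.0.2': 0}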

vllm/v1/engine/utils.py

Lines changed: 74 additions & 48 deletions
@@ -15,6 +15,7 @@
 import msgspec
 import zmq
 
+from vllm import envs
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -335,9 +336,11 @@ def create_dp_placement_groups(
         from ray._private.state import available_resources_per_node
 
         logger.info("Creating placement groups for data parallel")
-        dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip
-        num_pg_to_create = vllm_config.parallel_config.data_parallel_size
-        local_engine_count = vllm_config.parallel_config.data_parallel_size_local
+        dp_master_ip = \
+            vllm_config.parallel_config.data_parallel_master_ip
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        dp_size_local = \
+            vllm_config.parallel_config.data_parallel_size_local
 
         available_resources = available_resources_per_node()
         world_size = vllm_config.parallel_config.world_size
@@ -349,53 +352,76 @@ def create_dp_placement_groups(
         )
         assert len(nodes) > 0, "No nodes with resources found in Ray cluster."
         assert dp_master_ip_key in nodes[0], (
-            "The DP master node (ip: %s) is missing or dead",
-            dp_master_ip,
-        )
-        device_str = current_platform.ray_device_key
+            "The DP master node (ip: %s) is missing or dead", dp_master_ip)
+
+        if envs.VLLM_RAY_DP_PACK_STRATEGY == "strict":
+            logger.info(
+                "Using strict local size packing strategy based "
+                "on VLLM_RAY_DP_PACK_STRATEGY (%s)",
+                envs.VLLM_RAY_DP_PACK_STRATEGY)
+            strict_local_size = True
+        elif (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
+              or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"):
+            logger.info(
+                "Using strict local size packing strategy based "
+                "on VLLM_ALL2ALL_BACKEND (%s)", envs.VLLM_ALL2ALL_BACKEND)
+            strict_local_size = True
+        else:
+            logger.info(
+                "Using fill packing strategy based "
+                "on VLLM_RAY_DP_PACK_STRATEGY (%s)",
+                envs.VLLM_RAY_DP_PACK_STRATEGY)
+            strict_local_size = False
         for node_resources in nodes:
-            if device_str not in node_resources:
-                continue
-            # For now, each DP rank can only be assigned to one node
-            # TODO(rui): support allocating a single DP rank
-            # to multiple nodes
-            available_engine_count = int(node_resources[device_str]) // world_size
-            if dp_master_ip_key in node_resources:
-                assert available_engine_count >= local_engine_count, (
-                    "Not enough resources to allocate DP ranks "
-                    f"on DP master node {dp_master_ip}"
-                )
-                for i in range(local_engine_count):
-                    bundles = [
-                        {device_str: 1.0, "node:" + dp_master_ip: 0.001}
-                    ] * world_size + [{"CPU": 1.0}]
-                    pg = ray.util.placement_group(
-                        name=f"dp_rank_{len(placement_groups)}",
-                        strategy="STRICT_PACK",
-                        bundles=bundles,
-                    )
-                    placement_groups.append(pg)
-                    local_dp_ranks.append(i)
+            node_ip_keys = [
+                key for key in node_resources if key.startswith('node:')
+            ]
+            assert len(node_ip_keys) == 1, (
+                "Zero or multiple node IP keys found in node resources: %s",
+                node_ip_keys)
+            node_ip_key = node_ip_keys[0]
+            node_ip = node_ip_key.split(":")[1]
+            dp_size_available = int(node_resources["GPU"]) // world_size
+            if strict_local_size:
+                if dp_size_available < dp_size_local:
+                    if node_ip == dp_master_ip:
+                        raise ValueError(
+                            "Not enough resources to allocate %s DP ranks "
+                            "on DP master node %s, possible to fit %s DP ranks"
+                            % (dp_size_local, dp_master_ip, dp_size_available))
+                    else:
+                        logger.info(
+                            "Skipping node %s as %s DP ranks could not fit, "
+                            "possible to fit %s DP ranks", node_ip,
+                            dp_size_local, dp_size_available)
+                        continue
+                dp_size_to_allocate = dp_size_local
+            elif node_ip == dp_master_ip:
+                dp_size_to_allocate = dp_size_local
             else:
-                for i in range(available_engine_count):
-                    if len(placement_groups) == num_pg_to_create:
-                        break
-                    bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]
-                    pg = ray.util.placement_group(
-                        name=f"dp_rank_{len(placement_groups)}",
-                        strategy="STRICT_PACK",
-                        bundles=bundles,
-                    )
-                    placement_groups.append(pg)
-                    local_dp_ranks.append(i)
-        if len(placement_groups) < num_pg_to_create:
-            raise ValueError(
-                f"Not enough resources to allocate {num_pg_to_create} "
-                "placement groups, only created "
-                f"{len(placement_groups)} placement groups. "
-                "Available resources: "
-                f"{available_resources}"
-            )
+                dp_size_to_allocate = dp_size_available
+
+            for i in range(dp_size_to_allocate):
+                bundles = [{
+                    "GPU": 1.0,
+                    "node:" + node_ip: 0.001
+                }] * world_size + [{
+                    "CPU": 1.0
+                }]
+                pg = ray.util.placement_group(
+                    name=f"dp_rank_{len(placement_groups)}",
+                    strategy="STRICT_PACK",
+                    bundles=bundles,
+                )
+                placement_groups.append(pg)
+                local_dp_ranks.append(i)
+
+        if len(placement_groups) < dp_size:
+            raise ValueError(f"Not enough resources to allocate {dp_size} "
+                             "placement groups, only created "
+                             f"{len(placement_groups)} placement groups. "
+                             "Available resources: "
+                             f"{available_resources}")
         return placement_groups, local_dp_ranks
 
     @staticmethod
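
For reference, each DP rank's placement group built above consists of world_size single-GPU bundles pinned to one node through its node:<ip> custom resource, plus one CPU bundle for the engine-core process, all reserved atomically via STRICT_PACK. A minimal standalone sketch, assuming a running Ray cluster whose local node has at least two GPUs:

import ray
from ray.util.placement_group import placement_group

ray.init()
world_size = 2                            # GPUs per DP rank
node_ip = ray.util.get_node_ip_address()  # pin all bundles to this node
bundles = ([{"GPU": 1.0, "node:" + node_ip: 0.001}] * world_size
           + [{"CPU": 1.0}])
pg = placement_group(name="dp_rank_0", strategy="STRICT_PACK",
                     bundles=bundles)
ray.get(pg.ready())  # blocks until every bundle is reserved on that node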
