Commit d8ed040

[DP][ray] Support different VLLM_RAY_DP_PACK_STRATEGY (vllm-project#23849)

ruisearch42 authored and 0xrushi committed

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>

1 parent 6863306 commit d8ed040

File tree

2 files changed: +86 -33 lines changed

vllm/envs.py

Lines changed: 12 additions & 0 deletions
@@ -133,6 +133,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
+    VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_MXFP4_USE_MARLIN: Optional[bool] = None
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -1000,6 +1001,17 @@ def get_vllm_port() -> Optional[int]:
         "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0"
     )
     == "1",
+    # Strategy to pack the data parallel ranks for Ray.
+    # Available options:
+    # - "fill":
+    #   for DP master node, allocate exactly data-parallel-size-local DP ranks,
+    #   for non-master nodes, allocate as many DP ranks as can fit;
+    # - "strict":
+    #   allocate exactly data-parallel-size-local DP ranks to each picked node;
+    # This environment variable is ignored if data-parallel-backend is not Ray.
+    "VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
+        "VLLM_RAY_DP_PACK_STRATEGY", "strict"
+    ),
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
     # Use model_redirect to redirect the model name to a local folder.
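
The registration above resolves the value lazily through a lambda, so the
"strict" default applies only when the variable is unset. A minimal
standalone sketch of that resolution behavior (get_pack_strategy is an
illustrative helper, not a vLLM API):

import os

def get_pack_strategy() -> str:
    # Same lookup the registration above performs when vLLM reads its envs.
    return os.getenv("VLLM_RAY_DP_PACK_STRATEGY", "strict")

print(get_pack_strategy())  # "strict" while the variable is unset
os.environ["VLLM_RAY_DP_PACK_STRATEGY"] = "fill"
print(get_pack_strategy())  # "fill"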

vllm/v1/engine/utils.py

Lines changed: 74 additions & 33 deletions
@@ -15,6 +15,7 @@
 import msgspec
 import zmq
 
+from vllm import envs
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -337,8 +338,8 @@ def create_dp_placement_groups(
 
     logger.info("Creating placement groups for data parallel")
     dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip
-    num_pg_to_create = vllm_config.parallel_config.data_parallel_size
-    local_engine_count = vllm_config.parallel_config.data_parallel_size_local
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    dp_size_local = vllm_config.parallel_config.data_parallel_size_local
 
     available_resources = available_resources_per_node()
     world_size = vllm_config.parallel_config.world_size
@@ -354,44 +355,84 @@ def create_dp_placement_groups(
         dp_master_ip,
     )
     device_str = current_platform.ray_device_key
+
+    if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
+        envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
+        or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
+    ):
+        raise ValueError(
+            "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "
+            "to be on the same node, but VLLM_RAY_DP_PACK_STRATEGY=fill "
+            "does not guarantee that. "
+            "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
+        )
+    logger.info(
+        "Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY",
+        envs.VLLM_RAY_DP_PACK_STRATEGY,
+    )
+    strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict"
+
     for node_resources in nodes:
-        if device_str not in node_resources:
-            continue
+        node_ip_keys = [
+            key
+            for key in node_resources
+            if key != "node:__internal_head__" and key.startswith("node:")
+        ]
+        assert len(node_ip_keys) == 1, (
+            "Zero or multiple node IP keys found in node resources: %s",
+            node_ip_keys,
+        )
+        node_ip_key = node_ip_keys[0]
+        node_ip = node_ip_key.split(":")[1]
+
         # For now, each DP rank can only be assigned to one node
         # TODO(rui): support allocating a single DP rank
        # to multiple nodes
-        available_engine_count = int(node_resources[device_str]) // world_size
-        if dp_master_ip_key in node_resources:
-            assert available_engine_count >= local_engine_count, (
-                "Not enough resources to allocate DP ranks "
-                f"on DP master node {dp_master_ip}"
-            )
-            for i in range(local_engine_count):
-                bundles = [
-                    {device_str: 1.0, "node:" + dp_master_ip: 0.001}
-                ] * world_size + [{"CPU": 1.0}]
-                pg = ray.util.placement_group(
-                    name=f"dp_rank_{len(placement_groups)}",
-                    strategy="STRICT_PACK",
-                    bundles=bundles,
+        dp_size_available = (
+            int(node_resources[device_str]) // world_size
+            if device_str in node_resources
+            else 0
+        )
+
+        if node_ip == dp_master_ip:
+            if dp_size_available < dp_size_local:
+                raise ValueError(
+                    "Not enough resources to allocate %s DP ranks "
+                    "on DP master node %s, possible to fit %s DP ranks",
+                    dp_size_local,
+                    dp_master_ip,
+                    dp_size_available,
                 )
-                placement_groups.append(pg)
-                local_dp_ranks.append(i)
-        else:
-            for i in range(available_engine_count):
-                if len(placement_groups) == num_pg_to_create:
-                    break
-                bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}]
-                pg = ray.util.placement_group(
-                    name=f"dp_rank_{len(placement_groups)}",
-                    strategy="STRICT_PACK",
-                    bundles=bundles,
+            dp_size_to_allocate = dp_size_local
+        elif strict_local_size:
+            if dp_size_available < dp_size_local:
+                logger.info(
+                    "Skipping node %s as %s DP ranks could not fit, "
+                    "possible to fit %s DP ranks",
+                    node_ip,
+                    dp_size_local,
+                    dp_size_available,
                 )
-                placement_groups.append(pg)
-                local_dp_ranks.append(i)
-    if len(placement_groups) < num_pg_to_create:
+                continue
+            dp_size_to_allocate = dp_size_local
+        else:
+            dp_size_to_allocate = dp_size_available
+
+        for i in range(dp_size_to_allocate):
+            bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
+                {"CPU": 1.0}
+            ]
+            pg = ray.util.placement_group(
+                name=f"dp_rank_{len(placement_groups)}",
+                strategy="STRICT_PACK",
+                bundles=bundles,
+            )
+            placement_groups.append(pg)
+            local_dp_ranks.append(i)
+
+    if len(placement_groups) < dp_size:
         raise ValueError(
-            f"Not enough resources to allocate {num_pg_to_create} "
+            f"Not enough resources to allocate {dp_size} "
             "placement groups, only created "
             f"{len(placement_groups)} placement groups. "
             "Available resources: "

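Each allocated DP rank becomes one Ray placement group: world_size
accelerator bundles pinned to the chosen node through a tiny fraction of
Ray's node:<ip> custom resource, plus one CPU bundle for the engine-core
process, with STRICT_PACK keeping every bundle on a single node. A minimal
sketch against a running Ray cluster (the IP and the "GPU" key are
assumptions; the committed code takes the device key from
current_platform.ray_device_key):

import ray
from ray.util.placement_group import placement_group

ray.init()  # assumes a reachable Ray cluster with GPU nodes

world_size = 4        # accelerators spanned by one DP rank
node_ip = "10.0.0.1"  # hypothetical worker node

# world_size GPU bundles pinned to node_ip, plus one CPU bundle; STRICT_PACK
# forces all bundles onto the same node.
bundles = [{"GPU": 1.0, "node:" + node_ip: 0.001}] * world_size + [{"CPU": 1.0}]
pg = placement_group(bundles, strategy="STRICT_PACK", name="dp_rank_0")
ray.get(pg.ready())   # blocks until every bundle is reserved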
0 commit comments