Skip to content

Commit 37855db

Browse files
patrickvonplaten, gemini-code-assist[bot], and ruisearch42
authored and committed
[Data-parallel] Allow DP>1 for world_size > num_gpus on node (8) (vllm-project#26367)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent aa8507a commit 37855db

File tree

4 files changed

+96
-22
lines changed

4 files changed

+96
-22
lines changed

docs/serving/data_parallel_deployment.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ There are several notable differences when using Ray:
6969
- A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node
7070
- There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address`
7171
- There is no need to specify `--data-parallel-rpc-port`
72+
- When a single DP group requires multiple nodes, *e.g.* in case a single model replica needs to run on at least two nodes, make sure to set `VLLM_RAY_DP_PACK_STRATEGY="span"` in which case `--data-parallel-size-local` is ignored and will be automatically determined
7273
- Remote DP ranks will be allocated based on node resources of the Ray cluster
7374

7475
Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.

vllm/engine/arg_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,8 +1403,15 @@ def create_engine_config(
14031403
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
14041404
)
14051405

1406-
# Local DP size defaults to global DP size if not set.
1407-
data_parallel_size_local = self.data_parallel_size
1406+
if self.data_parallel_backend == "ray" and (
1407+
envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
1408+
):
1409+
# Data parallel size defaults to 1 if DP ranks are spanning
1410+
# multiple nodes
1411+
data_parallel_size_local = 1
1412+
else:
1413+
# Otherwise local DP size defaults to global DP size if not set
1414+
data_parallel_size_local = self.data_parallel_size
14081415

14091416
# DP address, used in multi-node case for torch distributed group
14101417
# and ZMQ sockets.

vllm/envs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
VLLM_DP_MASTER_PORT: int = 0
140140
VLLM_MOE_DP_CHUNK_SIZE: int = 256
141141
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
142-
VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
142+
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
143143
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
144144
VLLM_MXFP4_USE_MARLIN: bool | None = None
145145
VLLM_V0_USE_OUTLINES_CACHE: bool = False
@@ -1039,6 +1039,9 @@ def get_vllm_port() -> int | None:
10391039
# for non-master nodes, allocate as many DP ranks as can fit;
10401040
# - "strict":
10411041
# allocate exactly data-parallel-size-local DP ranks to each picked node;
1042+
# - "span":
1043+
# Should be used only when a single DP rank requires multiple nodes.
1044+
# allocate one DP rank over as many nodes as required for set world_size;
10421045
# This environment variable is ignored if data-parallel-backend is not Ray.
10431046
"VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
10441047
"VLLM_RAY_DP_PACK_STRATEGY", "strict"

vllm/v1/engine/utils.py

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def create_dp_placement_groups(
345345
world_size = vllm_config.parallel_config.world_size
346346
placement_groups: list[PlacementGroup] = []
347347
local_dp_ranks: list[int] = []
348+
348349
dp_master_ip_key = f"node:{dp_master_ip}"
349350
nodes = sorted(
350351
available_resources.values(), key=lambda x: dp_master_ip_key not in x
@@ -355,9 +356,25 @@ def create_dp_placement_groups(
355356
dp_master_ip,
356357
)
357358
device_str = current_platform.ray_device_key
359+
n_node_devices: list[int] = [
360+
int(node_resources[device_str])
361+
for node_resources in nodes
362+
if device_str in node_resources
363+
]
364+
assert n_node_devices, f"No {device_str} found in Ray cluster."
365+
max_device_per_node = max(n_node_devices)
366+
367+
pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY
368+
_supported_pack_strategies = ("strict", "fill", "span")
369+
if pack_strategy not in _supported_pack_strategies:
370+
raise ValueError(
371+
f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. "
372+
"Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` "
373+
f"to one of {_supported_pack_strategies}"
374+
)
358375

359376
all2all_backend = vllm_config.parallel_config.all2all_backend
360-
if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
377+
if pack_strategy == "fill" and (
361378
all2all_backend == "deepep_high_throughput"
362379
or all2all_backend == "deepep_low_latency"
363380
):
@@ -367,12 +384,42 @@ def create_dp_placement_groups(
367384
"does not guarantee that. "
368385
"Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
369386
)
370-
logger.info(
371-
"Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY",
372-
envs.VLLM_RAY_DP_PACK_STRATEGY,
373-
)
374-
strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict"
375387

388+
if pack_strategy in ("strict", "fill"):
389+
placement_strategy = "STRICT_PACK"
390+
else:
391+
placement_strategy = "PACK"
392+
assert world_size > max_device_per_node, (
393+
f"World size {world_size} is smaller than the "
394+
"maximum number of devices per node "
395+
f"{max_device_per_node}. Make sure to set "
396+
"`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`"
397+
)
398+
399+
# if we need multiple nodes per dp group, we require for now that
400+
# available nodes are homogenous
401+
assert set(n_node_devices) == {max_device_per_node}, (
402+
f"Nodes are not homogenous, {nodes}"
403+
)
404+
assert world_size % max_device_per_node == 0, (
405+
f"For multi-node data parallel groups, world_size ({world_size}) must "
406+
f"be a multiple of number of devices per node ({max_device_per_node})."
407+
)
408+
assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, (
409+
f"Not enough total available nodes ({len(n_node_devices)}) "
410+
f"and devices per node ({max_device_per_node}) "
411+
f"to satisfy required world size {world_size} and data parallel size "
412+
f"{dp_size}"
413+
)
414+
assert dp_size_local == 1, (
415+
f"data-parallel-size-local {dp_size_local} should be set as the "
416+
"default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. "
417+
"The actual data-parallel-size-local will be auto determined."
418+
)
419+
420+
# bundles collected for a single DP rank from multiple nodes,
421+
# for "span" pack strategy
422+
collected_bundles = []
376423
for node_resources in nodes:
377424
node_ip_keys = [
378425
key
@@ -386,14 +433,14 @@ def create_dp_placement_groups(
386433
node_ip_key = node_ip_keys[0]
387434
node_ip = node_ip_key.split(":")[1]
388435

389-
# For now, each DP rank can only be assigned to one node
390-
# TODO(rui): support allocating a single DP rank
391-
# to multiple nodes
392-
dp_size_available = (
393-
int(node_resources[device_str]) // world_size
394-
if device_str in node_resources
395-
else 0
396-
)
436+
n_device_on_node = int(node_resources.get(device_str, 0))
437+
if pack_strategy == "span" and n_device_on_node != 0:
438+
# Strictly speaking,
439+
# dp_size_available = n_device_on_node / world_size
440+
# and is a fraction, but we use 1 for easier processing
441+
dp_size_available = 1
442+
else:
443+
dp_size_available = n_device_on_node // world_size
397444

398445
if node_ip == dp_master_ip:
399446
if dp_size_available < dp_size_local:
@@ -405,7 +452,7 @@ def create_dp_placement_groups(
405452
dp_size_available,
406453
)
407454
dp_size_to_allocate = dp_size_local
408-
elif strict_local_size:
455+
elif pack_strategy == "strict":
409456
if dp_size_available < dp_size_local:
410457
logger.info(
411458
"Skipping node %s as %s DP ranks could not fit, "
@@ -417,15 +464,31 @@ def create_dp_placement_groups(
417464
continue
418465
dp_size_to_allocate = dp_size_local
419466
else:
467+
# for "pack_strategy" in "fill" and "span"
468+
# we always take everything that's available
420469
dp_size_to_allocate = dp_size_available
421470

422471
for i in range(dp_size_to_allocate):
423-
bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
424-
{"CPU": 1.0}
425-
]
472+
device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
473+
if pack_strategy == "span":
474+
collected_bundles += device_bundle * n_device_on_node
475+
assert len(collected_bundles) <= world_size, (
476+
"collected_bundles should be <= world_size, "
477+
f"but got {len(collected_bundles)=} and {world_size=}"
478+
)
479+
480+
# we only create a placement group if we collected enough devices
481+
if len(collected_bundles) < world_size:
482+
continue
483+
484+
bundles = collected_bundles + [{"CPU": 1.0}]
485+
collected_bundles = []
486+
else:
487+
bundles = device_bundle * world_size + [{"CPU": 1.0}]
488+
426489
pg = ray.util.placement_group(
427490
name=f"dp_rank_{len(placement_groups)}",
428-
strategy="STRICT_PACK",
491+
strategy=placement_strategy,
429492
bundles=bundles,
430493
)
431494
placement_groups.append(pg)

0 commit comments

Comments (0)