Commit c40692b

[Misc] Add parallel state node_count function (#20045)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 4734704 commit c40692b

File tree: 3 files changed (+98, -2 lines)


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -619,11 +619,13 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code

 - label: Distributed Tests (2 GPUs) # 40min
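
The new CI step sets NUM_NODES=2 so the test knows how many nodes to expect, and torchrun launches two ranks per node for a total world size of four. For local experimentation, a hedged single-machine equivalent (not part of the commit; it assumes the command is launched from vLLM's tests/ directory, where the relative path distributed/test_node_count.py resolves) might look like this sketch:

# Hypothetical local invocation of the new test (illustrative only).
# Both ranks land on the same machine, so the expected node count is 1,
# which is also the default the test uses when NUM_NODES is unset.
import os
import subprocess

env = dict(os.environ, NUM_NODES="1")
subprocess.run(
    ["torchrun", "--nnodes", "1", "--nproc-per-node=2",
     "distributed/test_node_count.py"],
    env=env,
    check=True,
)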
tests/distributed/test_node_count.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import torch.distributed as dist

from vllm.distributed.parallel_state import _node_count
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        port = get_open_port()
        ip = get_ip()
        dist.broadcast_object_list([ip, port], src=0)
    else:
        recv = [None, None]
        dist.broadcast_object_list(recv, src=0)
        ip, port = recv

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)

    for pg in [dist.group.WORLD, stateless_pg]:
        test_result = _node_count(pg)

        # Expected node count based on environment variable
        expected = int(os.environ.get("NUM_NODES", "1"))

        assert test_result == expected, \
            f"Expected {expected} nodes, got {test_result}"

        if pg == dist.group.WORLD:
            print(f"Node count test passed! Got {test_result} nodes "
                  f"when using torch distributed!")
        else:
            print(f"Node count test passed! Got {test_result} nodes "
                  f"when using StatelessProcessGroup!")

vllm/distributed/parallel_state.py

Lines changed: 53 additions & 2 deletions
@@ -802,6 +802,7 @@ def combine(self, hidden_states) -> torch.Tensor:
 
 
 _WORLD: Optional[GroupCoordinator] = None
+_NODE_COUNT: Optional[int] = None
 
 
 def get_world_group() -> GroupCoordinator:
@@ -961,10 +962,13 @@ def init_distributed_environment(
             local_rank = envs.LOCAL_RANK
         else:
             local_rank = rank
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
+        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        logger.debug("Detected %d nodes in the distributed environment",
+                     _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_node_count() -> int:
+    """Return the total number of nodes in the distributed environment. """
+    assert _NODE_COUNT is not None, (
+        "distributed environment is not initialized")
+    return _NODE_COUNT
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
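
As a quick illustration of the new accessor, here is a hedged usage sketch (not part of the commit; the explicit init arguments and the gloo backend are assumptions chosen so the snippet can run in a single CPU-only process). The node count is computed once during initialization and cached in _NODE_COUNT, so get_node_count() is just a lookup afterwards.

# Illustrative sketch only: initialize the distributed environment for a
# single local rank, then read back the cached node count.
from vllm.distributed.parallel_state import (get_node_count,
                                             init_distributed_environment)

init_distributed_environment(world_size=1, rank=0,
                             distributed_init_method="tcp://127.0.0.1:29500",
                             local_rank=0, backend="gloo")
print(get_node_count())  # 1 for a single-process, single-node setup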
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
 
 
 def destroy_distributed_environment():
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD:
         _WORLD.destroy()
     _WORLD = None
+    _NODE_COUNT = None
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
 
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
         aggregated_data += rank_data
 
     return [x == 1 for x in aggregated_data.tolist()]
+
+
+def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
+    """
+    Returns the total number of nodes in the process group.
+
+    Args:
+        pg: The process group to analyze
+
+    Returns:
+        int: The total number of nodes
+    """
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
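
To make the grouping loop above concrete, here is an illustrative, self-contained simulation (not vLLM code; the rank-to-host mapping and the same_node_flags_for helper are assumptions standing in for in_the_same_node_as). Four ranks spread over two hosts get node ids [1, 1, 2, 2], so the loop reports 2 nodes.

# Stand-alone simulation of the _node_count grouping loop (illustrative only).
hosts = ["node-a", "node-a", "node-b", "node-b"]  # assumed rank -> hostname
world_size = len(hosts)

def same_node_flags_for(rank: int) -> list[bool]:
    # Plays the role of in_the_same_node_as(pg, rank) in the real code.
    return [hosts[other] == hosts[rank] for other in range(world_size)]

node_assignment = [0] * world_size  # rank -> node_id (0 = unassigned)
next_node_id = 0
for current_rank in range(world_size):
    if node_assignment[current_rank] != 0:
        continue  # already grouped with an earlier rank
    next_node_id += 1
    node_assignment[current_rank] = next_node_id
    for other_rank, is_same in enumerate(same_node_flags_for(current_rank)):
        if is_same and node_assignment[other_rank] == 0:
            node_assignment[other_rank] = next_node_id

print(node_assignment)  # [1, 1, 2, 2]
print(next_node_id)     # 2 distinct nodes detected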
