
Commit cbbb8a8

Merge pull request #4 from panpan0000/ucc_integration-pr-ut
2 parents: a65a92f + a90900f

1 file changed: +213 −0 lines changed

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.ucc_communicator import (
    UCCCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024


def _select_device_and_dtype(local_rank: int):
    if current_platform.is_cuda():
        device = torch.device(f"cuda:{local_rank}")
        dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    return device, dtype


def ucc_allreduce_worker(local_rank: int, world_size: int):
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device, dtype = _select_device_and_dtype(local_rank)

        # Set device only for CUDA
        if current_platform.is_cuda():
            torch.cuda.set_device(device)
        # set_default_device may not exist in all torch versions
        if hasattr(torch, "set_default_device"):
            torch.set_default_device(device)
        torch.set_default_dtype(dtype)

        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        # Check if UCC is available
        if not UCCCommunicator.is_ucc_available():
            pytest.skip("UCC backend is not available in PyTorch.")

        # Create reference device group from TP group
        group = get_tensor_model_parallel_group(
        ).device_group  # pyright: ignore[reportDeprecated]

        # Try to create a UCC process group
        try:
            ucc_group = dist.new_group(backend="ucc")
        except Exception:
            pytest.skip("Failed to create UCC process group.")

        # Initialize UCC communicator
        ucc_communicator = UCCCommunicator(group=ucc_group, device=device)

        if ucc_communicator.disabled:
            pytest.skip("UCCCommunicator is disabled.")

        # Test direct UCC allreduce
        inp_direct_ucc = torch.randint(1,
                                       23, (test_size_elements, ),
                                       dtype=dtype,
                                       device=device)

        if not ucc_communicator.should_use_ucc_allreduce(inp_direct_ucc):
            pytest.skip(
                "UCCCommunicator isn't used for this world size and input size."
            )

        original_inp_direct_ucc = inp_direct_ucc.clone()
        out_direct_ucc = ucc_communicator.all_reduce(inp_direct_ucc)
        assert out_direct_ucc is not None

        # Compare with regular allreduce
        dist.all_reduce(original_inp_direct_ucc, group=group)

        # Tolerance based on dtype
        if dtype == torch.float32:
            atol, rtol = 1e-3, 1e-4
        else:
            atol, rtol = 2.5, 0.1
        torch.testing.assert_close(out_direct_ucc,
                                   original_inp_direct_ucc,
                                   atol=atol,
                                   rtol=rtol)

        # Test different reduction operations
        for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX, dist.ReduceOp.MIN]:
            inp_op_test = torch.randint(1,
                                        10, (1024, ),
                                        dtype=dtype,
                                        device=device)
            original_inp_op_test = inp_op_test.clone()

            out_ucc_op = ucc_communicator.all_reduce(inp_op_test, op=op)
            if out_ucc_op is not None:
                dist.all_reduce(original_inp_op_test, op=op, group=group)
                torch.testing.assert_close(out_ucc_op,
                                           original_inp_op_test,
                                           atol=atol,
                                           rtol=rtol)

        # Test tensor size threshold (avoid huge allocation by using meta)
        small_tensor = torch.ones(100, dtype=dtype, device=device)
        large_tensor = torch.empty(513 * 1024 * 1024,
                                   dtype=torch.uint8,
                                   device='meta')  # > 512MB, meta device

        assert ucc_communicator.should_use_ucc_allreduce(small_tensor) is True
        assert ucc_communicator.should_use_ucc_allreduce(large_tensor) is False

        # Test device mismatch handling
        cpu_tensor = torch.ones(100, dtype=dtype, device="cpu")
        out_cpu = ucc_communicator.all_reduce(cpu_tensor)
        if out_cpu is not None:
            assert out_cpu.device == device


def ucc_availability_worker(local_rank: int, world_size: int):
    """Test UCC availability detection"""
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device, _ = _select_device_and_dtype(local_rank)
        if current_platform.is_cuda():
            torch.cuda.set_device(device)

        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12347',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        # Test static method
        is_available = UCCCommunicator.is_ucc_available()
        assert isinstance(is_available, bool)

        if not is_available:
            pytest.skip("UCC backend is not available in PyTorch.")

        # Test with non-UCC group (should disable communicator)
        gloo_group = dist.new_group(backend="gloo")
        ucc_comm_with_gloo = UCCCommunicator(group=gloo_group, device=device)
        assert ucc_comm_with_gloo.disabled is True


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_ucc_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                       pipeline_parallel_size):
    world_size = tp_size * pipeline_parallel_size

    # For CUDA, ensure enough GPUs; for CPU, proceed.
    if current_platform.is_cuda() and world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    mp.spawn(ucc_allreduce_worker, args=(world_size, ), nprocs=world_size)
    cleanup_dist_env_and_memory()


@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_ucc_availability(monkeypatch: pytest.MonkeyPatch, tp_size,
                          pipeline_parallel_size):
    world_size = tp_size * pipeline_parallel_size

    if current_platform.is_cuda() and world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    mp.spawn(ucc_availability_worker, args=(world_size, ), nprocs=world_size)
    cleanup_dist_env_and_memory()


def test_ucc_communicator_initialization():
    """Basic check that static availability method works."""
    is_available = UCCCommunicator.is_ucc_available()
    assert isinstance(is_available, bool)


def test_ucc_static_methods():
    """Test static methods of UCCCommunicator"""
    # Test is_ucc_available static method
    is_available = UCCCommunicator.is_ucc_available()
    assert isinstance(is_available, bool)
    # The method should not crash regardless of environment
    # and should return a boolean value
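
For orientation, and not part of the diff above: a minimal sketch of the UCCCommunicator usage pattern these tests exercise, assuming a PyTorch build with the UCC backend and a distributed environment already initialized the same way the workers do. Only calls that appear in the test file are used; tensor sizes and the CPU device are illustrative choices.

# Illustration only -- not part of the committed test file.
# Assumes init_distributed_environment() / initialize_model_parallel() have
# already run in this process, exactly as in the workers above.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.ucc_communicator import (
    UCCCommunicator)

if UCCCommunicator.is_ucc_available():
    ucc_group = dist.new_group(backend="ucc")   # UCC-backed process group
    comm = UCCCommunicator(group=ucc_group, device=torch.device("cpu"))
    if not comm.disabled:
        x = torch.ones(1024, dtype=torch.float32)
        out = comm.all_reduce(x)                # may return None if UCC declines
        if out is not None:
            assert out.shape == x.shape

The allreduce and availability tests above wrap this same flow in mp.spawn workers so that every rank runs it, then compare the UCC result against a regular dist.all_reduce on the tensor-parallel group.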
