Skip to content

Commit 0bf3002

Browse files
ilmarkovilmarkov
authored andcommitted
Fixes after rebase
Signed-off-by: ilmarkov <imarkov@redhat.com>
1 parent f3a267c commit 0bf3002

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import random
5+
import typing
6+
7+
import pytest
8+
import torch
9+
import torch.distributed as dist
10+
import torch.multiprocessing as mp
11+
12+
import vllm.envs as envs
13+
from vllm.distributed import cleanup_dist_env_and_memory
14+
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
15+
from vllm.distributed.device_communicators.cuda_communicator import (
16+
CudaCommunicator)
17+
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
18+
get_tp_group,
19+
init_distributed_environment,
20+
initialize_model_parallel)
21+
from vllm.platforms import current_platform
22+
from vllm.utils import update_environment_variables
23+
24+
# Fixed seeds so every spawned worker process draws the same pseudo-random
# tensors, keeping the all-reduce comparison deterministic across runs.
torch.manual_seed(42)
random.seed(44)

# Number of elements in each all-reduce input tensor (4 Mi elements).
# NOTE(review): presumably sized to exceed the symm-mem size threshold so
# should_use_symm_mem() accepts it — confirm against the communicator config.
test_size_elements = 4 * 1024 * 1024
29+
30+
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
    """Per-process worker spawned by ``mp.spawn``.

    Initializes a tensor-parallel process group on one GPU, then checks the
    SymmMemCommunicator's all-reduce output against a reference
    ``torch.distributed.all_reduce`` on the same data — both for a direct
    ``symm_mem_comm.all_reduce`` call and via
    ``tensor_model_parallel_all_reduce``.

    Args:
        local_rank: rank of this worker; also selects ``cuda:{local_rank}``.
        world_size: total number of workers / tensor-parallel size.
    """
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        # Drop any inherited CUDA_VISIBLE_DEVICES so device indices map
        # one-to-one onto local ranks.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        # Rank/master env vars must be in place before
        # init_distributed_environment() reads them.
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

        # Integer-valued inputs keep the expected sum exactly representable;
        # range [1, 23) per rank.
        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world and input size."
            )

        # Clone before the symm-mem all-reduce: the reference all_reduce below
        # must see the original data.
        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        # Reference result from the regular torch.distributed all-reduce.
        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        # Loose tolerances — presumably to absorb bf16 rounding differences
        # between the two reduction implementations; confirm if tightening.
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)
89+
90+
91+
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    """Spawn one worker per rank and verify symmetric-memory all-reduce.

    Each spawned process runs :func:`symm_mem_allreduce_worker`, which builds
    its own TP group and compares symm-mem all-reduce results against
    ``torch.distributed.all_reduce``.

    Args:
        monkeypatch: pytest fixture used to enable the symm-mem path.
        tp_size: tensor-parallel size (number of GPUs used).
        pipeline_parallel_size: pipeline-parallel size (kept at 1 here).
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    try:
        mp.spawn(symm_mem_allreduce_worker,
                 args=(world_size, ),
                 nprocs=world_size)
    finally:
        # Bug fix: previously cleanup was skipped when a worker failed and
        # mp.spawn raised, leaking distributed state and GPU memory into
        # subsequent tests. Always tear down, even on failure.
        cleanup_dist_env_and_memory()

0 commit comments

Comments
 (0)