
Commit 6e00aed

Authored by: weijinqian0, weijinqian_v1, whx-sjtu, Irving11-BKN, Potabk
[main][Feature] MoE alltoallv communication optimization for unquantized RL training scenarios (#2088)
It comes from 0.9.1dev: [0.9.1][Feature] MoE alltoallv communication optimization for unquantized RL training scenarios & alltoallv support for dpo (#1547)

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@97608dc

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
Signed-off-by: taoxudonghaha <justsheldon@163.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: TaoYu Chen <ctynb@qq.com>
Co-authored-by: taoxudonghaha <justsheldon@163.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
1 parent f0c1f0c commit 6e00aed
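
How this lands for users (an illustrative sketch, not part of the diff): setting VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ=1 before engine start routes unquantized MoE layers through the new alltoall_seq dispatcher. The snippet below mirrors the new e2e test test_models_distributed_alltoallv using the standard upstream vllm.LLM offline API; the model name, dtype, and parallel settings are taken from that test.

# Illustrative usage only -- mirrors the new e2e test, not code from this commit.
import os

# Must be set before the engine (and the Ascend forward context) is created.
os.environ["VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ"] = "1"

from vllm import LLM, SamplingParams  # standard upstream vLLM offline API

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # MoE model used by the e2e test
    dtype="half",
    tensor_parallel_size=2,
    distributed_executor_backend="mp",
)
outputs = llm.generate(
    ["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."],
    SamplingParams(max_tokens=5, temperature=0.0, top_k=50, top_p=0.9),
)
print(outputs[0].outputs[0].text)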

File tree: 14 files changed, +1265 −17 lines


.github/workflows/vllm_ascend_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -278,6 +278,7 @@ jobs:
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
+          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
           pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@ ray>=2.47.1
 protobuf>3.20.0
 librosa
 soundfile
+pytest_mock

tests/e2e/multicard/test_offline_inference_distributed.py

Lines changed: 22 additions & 0 deletions
@@ -157,6 +157,28 @@ def test_models_distributed_topk() -> None:
         vllm_model.generate(example_prompts, sampling_params)


+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"})
+def test_models_distributed_alltoallv() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner(
+            "deepseek-ai/DeepSeek-V2-Lite",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 def test_models_distributed_Qwen3_W8A8():
     example_prompts = [
         "Hello, my name is",
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import importlib

import pytest
import torch
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.distributed.tensor_parallel import (
    _gather_along_first_dim, _gather_along_last_dim,
    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
    all_to_all_hp2sp, all_to_all_sp2hp)


class TestDistributedCommunication(PytestBase):

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.distributed.get_world_size", return_value=4)

        mocker.patch("torch.distributed.get_rank", return_value=0)

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16), (8, 16)),
                              (4, torch.randn(8, 16), (32, 16))])
    def test_gather_along_first_dim(self, test_tensor, expected, world_size,
                                    mocker: MockerFixture):
        """test _gather_along_first_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_first_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
        (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
    ])
    def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
                                                  output_split_sizes,
                                                  mocker: MockerFixture):
        """test _gather_along_first_dim"""

        result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
                                         output_split_sizes)

        assert result.shape == expected

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16, 32), (8, 16, 32)),
                              (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
    def test_gather_along_last_dim(self, test_tensor, expected, world_size,
                                   mocker: MockerFixture):
        """test _gather_along_last_dim"""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)

        result = _gather_along_last_dim(test_tensor, mocker.MagicMock())

        assert result.shape == expected

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
                                            mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_first_dim(input_tensor,
                                                 mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((8, 16, 32), (8, 16, 8)),
    ])
    def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
                                           mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = _reduce_scatter_along_last_dim(input_tensor,
                                                mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, func, input_shape, expected_shape,
                               mocker: MockerFixture):
        """test wrapper funcs"""
        mod = importlib.import_module(
            'vllm_ascend.distributed.tensor_parallel')
        globals = mod.__dict__
        test_func = globals[func]
        input_tensor = torch.randn(*input_shape)
        result = test_func(input_tensor, mocker.MagicMock())
        assert result.shape == expected_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
        ])
    def test_all_to_all_sp2hp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
        ])
    def test_all_to_all_hp2sp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        input_tensor = torch.randn(*input_shape)
        result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
        assert result.shape == output_shape
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import pytest
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
    MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import adapt_patch  # noqa E402


class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase):

    @pytest.fixture
    def config(self):
        config = MoEDispatcherConfig()
        config.set_num_local_experts(2)
        config.set_num_moe_experts(4)
        config.set_moe_pad_expert_input_to_capacity(False)
        config.set_moe_expert_capacity_factor(None)
        config.set_moe_router_topk(2)
        config.set_moe_grouped_gemm(False)
        config.set_group_topk(0)
        config.set_num_groups(1)
        config.set_is_fused(False)
        return config.build()

    def mock_ep_group(self, mocker):
        mock_group = mocker.MagicMock()
        mock_group.rank_in_group = 0
        mock_group.world_size = 2
        mock_group.device_group = "mock_group"
        return mock_group

    @pytest.fixture
    def dispatcher(self, config, mocker: MockerFixture):
        mocker.patch(
            "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group",
            return_value=self.mock_ep_group(mocker))
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.npu.Stream", return_value=mocker.MagicMock)
        return MoEAlltoAllSeqOverLapDispatcher(config)

    def test_initialization(self, dispatcher, config):
        assert dispatcher.num_local_experts == config.num_local_experts
        assert dispatcher.num_experts == config.num_moe_experts
        assert dispatcher.local_expert_indices == [0, 1]
        assert dispatcher.ep_rank == 0
        assert dispatcher.ep_size == 2
        assert dispatcher.overlap_stream is not None

vllm_ascend/ascend_forward_context.py

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@ class FusedMoEState(Enum):
     MC2 = 2
     AllGatherEP = 3
     NaiveMulticast = 4
+    All2AllSeq = 5


 # TODO(zzzzwwjj): add soc_version to choose branch
@@ -33,6 +34,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
             return FusedMoEState.NaiveMulticast
         else:
             return FusedMoEState.AllGather
+    elif envs.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ:
+        # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage.
+        return (FusedMoEState.All2AllSeq if
+                (ep_size < 16 or with_prefill) else FusedMoEState.MC2)
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
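
To make the branch ordering above easier to read, here is a standalone restatement of the new selection rule. This is illustrative only: the AllGather/All2All enum values are assumed (they are not shown in this hunk), and the surrounding ep_size == 1 and torchair branches are elided.

# Standalone restatement of the added branch (illustrative, not repository code).
from enum import Enum


class FusedMoEState(Enum):
    AllGather = 0       # value assumed; not shown in this hunk
    All2All = 1         # value assumed; not shown in this hunk
    MC2 = 2
    AllGatherEP = 3
    NaiveMulticast = 4
    All2AllSeq = 5


def select_state_with_all2all_seq(ep_size: int, with_prefill: bool) -> FusedMoEState:
    # Mirrors the new elif branch: alltoall_seq for prefill or small EP groups,
    # MC2 for decode once ep_size >= 16, since MC2 dispatch/combine is faster
    # in the decoding stage.
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2AllSeq
    return FusedMoEState.MC2


# Quick sanity checks of the branch behaviour.
assert select_state_with_all2all_seq(8, False) is FusedMoEState.All2AllSeq
assert select_state_with_all2all_seq(16, True) is FusedMoEState.All2AllSeq
assert select_state_with_all2all_seq(16, False) is FusedMoEState.MC2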
