Commit 562a001

apply npu_moe_gating_top_k_softmax

Signed-off-by: huangxialu <huangxialu1@huawei.com>

1 parent 4014ad2

File tree

2 files changed: +127 -1 lines changed

tests/singlecard/ops/test_fused_moe.py

Lines changed: 121 additions & 1 deletion
@@ -23,11 +23,15 @@
 # here to make the test pass.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa
 
+from unittest.mock import MagicMock, patch
+
 import pytest
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ascend_forward_context import FusedMoEState
+from vllm_ascend.ops.fused_moe import (AscendUnquantizedFusedMoEMethod,
+                                       fused_experts)
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]

@@ -98,3 +102,119 @@ def test_fused_experts(
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
     torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
     torch.npu.empty_cache()
+
+
+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("renormalize", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_ascend_unquantized_fused_moe_softmax(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    renormalize: bool,
+    dtype: torch.dtype,
+    device: str,
+):
+
+    class MockVllmConfig:
+
+        @property
+        def scheduler_config(self):
+
+            class SchedulerConfig:
+                max_num_seqs = 256
+
+            return SchedulerConfig()
+
+        @property
+        def model_config(self):
+
+            class ModelConfig:
+                max_model_len = 2048
+
+            return ModelConfig()
+
+    class MockAscendConfig:
+
+        @property
+        def torchair_graph_config(self):
+
+            class TorchairGraphConfig:
+                enabled = False
+
+            return TorchairGraphConfig()
+
+    class MockMC2Group:
+
+        @property
+        def device_group(self):
+            return MagicMock()
+
+    class MockForwardContext:
+
+        @property
+        def fused_moe_state(self):
+            return FusedMoEState.AllGather
+
+    class MockLayer(torch.nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.w13_weight = torch.randn(
+                (e, 2 * n, k), device=device, dtype=dtype) / 10
+            self.w2_weight = torch.randn(
+                (e, k, n), device=device, dtype=dtype) / 10
+
+    x = torch.randn((m, k), device=device, dtype=dtype) / 10
+    router_logits = torch.randn((m, e), device=device, dtype=dtype)
+
+    with patch('vllm_ascend.ops.fused_moe.get_current_vllm_config') as mock_get_vllm_config, \
+            patch('vllm_ascend.ops.fused_moe.get_ascend_config') as mock_get_ascend_config, \
+            patch('vllm_ascend.ops.fused_moe.get_mc2_group') as mock_get_mc2_group, \
+            patch('vllm_ascend.ops.fused_moe.get_forward_context') as mock_get_context, \
+            patch('vllm_ascend.ops.fused_moe.fused_experts') as mock_fused_experts, \
+            patch('torch.distributed.get_rank') as mock_get_rank:
+        mock_get_vllm_config.return_value = MockVllmConfig()
+        mock_get_ascend_config.return_value = MockAscendConfig()
+        mock_get_mc2_group.return_value = MockMC2Group()
+        mock_get_context.return_value = MockForwardContext()
+        mock_fused_experts.return_value = torch.zeros_like(x)
+
+        mock_get_rank.side_effect = AttributeError("mock error")
+
+        method = AscendUnquantizedFusedMoEMethod()
+        layer = MockLayer()
+
+        output = method.apply(
+            layer=layer,
+            x=x,
+            router_logits=router_logits,
+            top_k=topk,
+            renormalize=renormalize,
+            scoring_func="softmax",
+            global_num_experts=e,
+        )
+
+        assert method.moe_all_to_all_group_name is None
+
+        assert mock_fused_experts.called
+        call_args = mock_fused_experts.call_args[1]
+
+        topk_weights = call_args['topk_weights']
+        topk_ids = call_args['topk_ids']
+
+        assert topk_weights.shape == (m, topk)
+        assert topk_ids.shape == (m, topk)
+        assert output.shape == (m, k)
+        assert topk_weights.dtype == dtype
+        assert topk_ids.dtype == torch.int32
+        assert output.dtype == dtype
+
+        torch.npu.empty_cache()
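
To exercise just the new test locally, a minimal pytest runner sketch; it assumes pytest is installed and an Ascend NPU is available, since the test allocates tensors on the parametrized device and calls torch.npu.empty_cache(). This runner is not part of the commit.

# Hypothetical runner: selects only the new test via pytest's -k expression
# and exits with pytest's return code.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([
        "tests/singlecard/ops/test_fused_moe.py",
        "-k", "test_ascend_unquantized_fused_moe_softmax",
        "-v",
    ]))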

vllm_ascend/ops/fused_moe.py

Lines changed: 6 additions & 0 deletions
@@ -1005,6 +1005,12 @@ def apply(
                 routed_scaling_factor=1,
                 eps=float(1e-20),
             )
+        elif scoring_func == "softmax":
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                x=router_logits, finished=None, k=top_k)
+            if renormalize:
+                topk_weights = topk_weights / topk_weights.sum(dim=-1,
+                                                               keepdim=True)
         else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
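
For reference, a minimal pure-PyTorch sketch of the gating math the new branch relies on: softmax over the router logits, then top-k selection, with the same optional renormalization the commit adds. This only illustrates the assumed semantics of the fused kernel, not its implementation; softmax_top_k_reference is a hypothetical helper that exists in neither torch_npu nor vllm_ascend.

# Hypothetical reference of the assumed softmax top-k gating semantics.
# The real code path calls the fused kernel torch_npu.npu_moe_gating_top_k_softmax.
import torch


def softmax_top_k_reference(router_logits: torch.Tensor, k: int,
                            renormalize: bool = False):
    # Probabilities over experts per token: shape (num_tokens, num_experts).
    scores = torch.softmax(router_logits, dim=-1)
    # Keep the k largest weights and their expert indices for each token.
    topk_weights, topk_ids = torch.topk(scores, k, dim=-1)
    if renormalize:
        # Rescale so the k selected weights sum to 1 per token, mirroring
        # the `renormalize` branch added by this commit.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # int32 ids match the dtype the new test asserts for topk_ids.
    return topk_weights, topk_ids.to(torch.int32)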
