Commit 9c9a7cd
[main] adapt usage of npu_moe_gating_top_k_softmax and remove envs.SELECT_GATING_TOPK_SOTFMAX_EXPERTS (#2112)
backport of v0.9.1-dev: #1902
origin main npu_moe_gating_top_k_softmax: #1355

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@055bd39

Signed-off-by: huangxialu <huangxialu1@huawei.com>
1 parent e8660d7 commit 9c9a7cd
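
For context, the fused kernel this commit funnels plain softmax gating through, torch_npu.npu_moe_gating_top_k_softmax, computes softmax followed by top-k selection in a single call. A minimal pure-PyTorch sketch of the roughly equivalent computation (the helper name below is illustrative, not part of the codebase):

import torch

def gating_topk_softmax_reference(router_logits: torch.Tensor, top_k: int,
                                  renormalize: bool = False):
    # Softmax over the expert dimension, then keep the top-k per token.
    scores = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    # npu_moe_init_routing downstream expects int32 expert ids.
    topk_ids = topk_ids.to(torch.int32)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

# Example: route 4 tokens over 8 experts, top-2 per token.
weights, ids = gating_topk_softmax_reference(torch.randn(4, 8), top_k=2,
                                             renormalize=True)
assert weights.shape == (4, 2) and ids.dtype == torch.int32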

File tree

5 files changed, +146 -89 lines

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 97 additions & 1 deletion
@@ -23,11 +23,13 @@
 # here to make the test pass.
 import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa

+from unittest.mock import MagicMock, patch
+
 import pytest
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul

-from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ops.fused_moe import fused_experts, select_experts

 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -98,3 +100,97 @@ def test_fused_experts(
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
     torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
     torch.npu.empty_cache()
+
+
+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
+@pytest.mark.parametrize("use_grouped_topk", [True, False])
+@pytest.mark.parametrize("renormalize", [True, False])
+@pytest.mark.parametrize("with_e_correction", [True, False])
+@pytest.mark.parametrize("custom_routing", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_select_experts(
+    m: int,
+    n: int,
+    e: int,
+    topk: int,
+    scoring_func: str,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    with_e_correction: bool,
+    custom_routing: bool,
+    dtype: torch.dtype,
+    device: str,
+):
+    topk_group = 4 if use_grouped_topk else None
+    num_expert_group = e // 4 if use_grouped_topk else None
+
+    hidden_states = torch.randn(m, n, device=device, dtype=dtype)
+    router_logits = torch.randn(m, e, device=device, dtype=dtype)
+
+    e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
+                               if with_e_correction else None)
+
+    custom_routing_function = None
+    if custom_routing:
+        custom_routing_function = MagicMock()
+        mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
+        mock_ids = torch.randint(0,
+                                 e, (m, topk),
+                                 device=device,
+                                 dtype=torch.int32)
+        custom_routing_function.return_value = (mock_weights, mock_ids)
+
+    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
+               ) as mock_native_grouped_topk:
+        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
+            x)
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=topk,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+    if use_grouped_topk:
+        mock_native_grouped_topk.assert_called_once()
+    else:
+        mock_native_grouped_topk.assert_not_called()
+
+    assert topk_weights.shape == (m, topk)
+    assert topk_ids.shape == (m, topk)
+    assert topk_ids.dtype == torch.int32
+
+
+@pytest.mark.parametrize("device", DEVICE)
+def test_select_experts_invalid_scoring_func(device: str):
+    with pytest.raises(ValueError,
+                       match="Unsupported scoring function: invalid"):
+        select_experts(hidden_states=torch.randn(1, 128, device=device),
+                       router_logits=torch.randn(1, 8, device=device),
+                       top_k=2,
+                       use_grouped_topk=False,
+                       renormalize=False,
+                       scoring_func="invalid")
+
+
+@pytest.mark.parametrize("device", DEVICE)
+def test_select_experts_missing_group_params(device: str):
+    with pytest.raises(AssertionError):
+        select_experts(hidden_states=torch.randn(1, 128, device=device),
+                       router_logits=torch.randn(1, 64, device=device),
+                       top_k=2,
+                       use_grouped_topk=True,
+                       renormalize=False,
+                       scoring_func="softmax")

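The MagicMock above stands in for a user-supplied custom_routing_function. Judging from the call sites in this commit, such a function receives hidden_states, gating_output, topk, and renormalize and returns a (topk_weights, topk_ids) pair. A hypothetical minimal implementation satisfying that contract, as a sketch only (not part of the codebase):

import torch

def naive_custom_routing(hidden_states: torch.Tensor,
                         gating_output: torch.Tensor, topk: int,
                         renormalize: bool):
    # Ordinary softmax top-k routing, shaped like the mock's return value:
    # weights of (num_tokens, topk) and int32 ids of (num_tokens, topk).
    weights = gating_output.softmax(dim=-1)
    topk_weights, topk_ids = weights.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)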
tests/ut/ops/test_fused_ops.py

Lines changed: 5 additions & 8 deletions
@@ -297,9 +297,8 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
         assert not layer.w13_weight.requires_grad
         assert not layer.w2_weight.requires_grad

-    @pytest.mark.parametrize(
-        "others_param",
-        [[256, 4, False], [128, 1, False], [128, 1, True], [128, 4, False]])
+    @pytest.mark.parametrize("others_param",
+                             [[256, 4], [128, 1], [128, 1], [128, 4]])
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
         """
@@ -308,15 +307,13 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
         3 test use select_gating_topk_softmax_experts and fused_experts
         4 test use select_experts and fused_experts_with_all2all_buffer
         """
-        global_num_experts, ep_size, select_softmax = others_param
+        global_num_experts, ep_size = others_param
         is_prefill = False
         is_deepseek_v3_r1 = global_num_experts == 256
         forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
             ep_size, is_prefill, is_deepseek_v3_r1))
-        with patch(
-                "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
-                select_softmax), \
-            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+                   return_value=forward_context):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             router_logits = torch.randn(8, 8)

vllm_ascend/envs.py

Lines changed: 0 additions & 5 deletions
@@ -117,11 +117,6 @@
     # value to False to disable the optimized model.
     "USE_OPTIMIZED_MODEL":
     lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
-    # SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
-    # In theory, it should have better performance than select_experts.
-    # Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
-    "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
-    lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
     # The tolerance of the kv cache size, if the difference between the
     # actual kv cache size and the cached kv cache size is less than this value,
     # then the cached kv cache size will be used.

vllm_ascend/ops/common_fused_moe.py

Lines changed: 14 additions & 24 deletions
@@ -22,13 +22,10 @@
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod

-import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
-                                       select_experts,
-                                       select_gating_top_k_softmax_experts)
+                                       select_experts)
 from vllm_ascend.utils import is_310p

-SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -61,26 +58,19 @@ def forward_oot(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

-    if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
-        topk_weights, topk_ids = select_gating_top_k_softmax_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            renormalize=renormalize)
-    else:
-        topk_weights, topk_ids = select_experts(
-            global_num_experts=global_num_experts,
-            hidden_states=x,
-            router_logits=router_logits,
-            top_k=top_k,
-            use_grouped_topk=use_grouped_topk,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-        )
+    topk_weights, topk_ids = select_experts(
+        global_num_experts=global_num_experts,
+        hidden_states=x,
+        router_logits=router_logits,
+        top_k=top_k,
+        use_grouped_topk=use_grouped_topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        scoring_func=scoring_func,
+        e_score_correction_bias=e_score_correction_bias,
+    )

     if topk_ids.shape[1] < top_k or is_310p():
         assert global_num_experts is not None

vllm_ascend/ops/fused_moe.py

Lines changed: 30 additions & 51 deletions
@@ -52,7 +52,6 @@
                                get_rm_router_logits_state, is_310p)

 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
-SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS


 def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
@@ -859,39 +858,6 @@ def fused_experts(
     return final_hidden_states


-def select_gating_top_k_softmax_experts(
-        hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
-        renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-    Only supports float16, bfloat16, and float32.
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        renormalize: Whether to renormalize the routing weights.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-    topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
-        router_logits, None, k=top_k)
-
-    # # Required by npu_moe_init_routing
-    # topk_weights = topk_weights.to(hidden_states.dtype)
-    # topk_ids = topk_ids.to(torch.int32)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
 def native_grouped_topk(
     topk_weights: torch.Tensor,
     num_expert_group: Optional[int],
@@ -953,8 +919,24 @@ def select_experts(
         ValueError: If an unsupported scoring function is provided.
     """

+    def _renormalize_topk_weights(
+        topk_weights: torch.Tensor,
+        renormalize: bool,
+    ):
+        if renormalize:
+            topk_weights = topk_weights / topk_weights.sum(dim=-1,
+                                                           keepdim=True)
+        return topk_weights
+
     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
+        if not use_grouped_topk and custom_routing_function is None:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                x=router_logits, finished=None, k=top_k)
+            topk_ids = topk_ids.to(torch.int32)
+            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
+            return topk_weights, topk_ids
+
         topk_weights = router_logits.softmax(dim=-1)
     elif scoring_func == "sigmoid":
         topk_weights = router_logits.sigmoid()
@@ -988,10 +970,11 @@
                                             k=top_k,
                                             dim=-1,
                                             sorted=False)
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-        topk_weights = topk_weights.to(hidden_states.dtype)
-    else:
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
+        return topk_weights, topk_ids
+
+    if custom_routing_function is not None:
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -1002,11 +985,12 @@
         topk_ids = topk_ids.to(torch.int32)
         return topk_weights, topk_ids

+    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
+    topk_weights = topk_weights.to(hidden_states.dtype)
+
     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

     return topk_weights, topk_ids

@@ -1070,23 +1054,18 @@ def apply(
         if is_deepseek_v3_r1:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
-                k=top_k,  # topk is currently written as 8
+                k=top_k,  # topk currently is 8
                 bias=e_score_correction_bias,
                 k_group=topk_group,  # fix: 4
                 group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+                group_select_mode=
+                1,  # 0: the maximum in the group; 1: topk2.sum(fix)
                 renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
                 norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; whether to output the third output
-                # y2_flag=False, # old api; whether to output the third output
+                # out_flag=False, # todo new api; whether to output the third output
+                # y2_flag=False, # old api; whether to output the third output
                 routed_scaling_factor=1,
                 eps=float(1e-20))
-        elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
-            topk_weights, topk_ids = select_gating_top_k_softmax_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                renormalize=renormalize)
         else:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,

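For readers unfamiliar with the grouped path that the e2e test mocks via native_grouped_topk: group-limited routing first scores groups of experts, keeps only the best topk_group groups per token, masks the rest, and then picks top-k experts from what remains. A rough pure-PyTorch sketch of that idea, assuming groups are scored by their strongest expert; this is an illustration, not the vllm_ascend implementation:

import torch

def grouped_topk_sketch(scores: torch.Tensor, num_expert_group: int,
                        topk_group: int, top_k: int):
    # scores: (num_tokens, num_experts), already softmax/sigmoid-activated.
    num_tokens, num_experts = scores.shape
    grouped = scores.view(num_tokens, num_expert_group, -1)
    # Score each group by its strongest expert (one common choice).
    group_scores = grouped.max(dim=-1).values
    # Keep the topk_group best groups per token, mask out the rest.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(-1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(
        num_tokens, num_experts)
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    # Ordinary top-k over the surviving experts.
    topk_weights, topk_ids = masked.topk(top_k, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)

# Example: 2 tokens, 8 experts in 4 groups; keep 2 groups, then top-2 experts.
w, ids = grouped_topk_sketch(torch.randn(2, 8).softmax(-1), 4, 2, 2)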