
Commit bc3fb6e

yangcheng (AJ) authored and committed

refactor

Signed-off-by: yangcheng <yangcheng104@huawei.com>

1 parent 1a70564 · commit bc3fb6e

File tree

10 files changed: +359 −370 lines changed

tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 3 additions & 2 deletions

@@ -26,7 +26,8 @@
 import torch
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.fused_moe import fused_experts, select_experts
+from vllm_ascend.ops.fused_moe import fused_experts
+from vllm_ascend.ops.layers.experts_selector import select_experts
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -142,7 +143,7 @@ def test_select_experts(
                              dtype=torch.int32)
     custom_routing_function.return_value = (mock_weights, mock_ids)
 
-    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
+    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
               ) as mock_native_grouped_topk:
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
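
For downstream callers, the visible effect of this hunk is the import split: fused_experts stays in vllm_ascend.ops.fused_moe, while select_experts now comes from vllm_ascend.ops.layers.experts_selector. A minimal sketch of the relocated call, with the signature taken from the function body removed from fused_moe.py later in this commit (shapes are illustrative; an Ascend NPU with torch_npu is assumed to be available at runtime):

    import torch

    from vllm_ascend.ops.fused_moe import fused_experts
    from vllm_ascend.ops.layers.experts_selector import select_experts

    hidden_states = torch.randn(4, 128)  # (num_tokens, hidden_size)
    router_logits = torch.randn(4, 8)    # (num_tokens, num_experts)

    # Plain softmax scoring without grouping; per the removed implementation
    # this takes the npu_moe_gating_top_k_softmax fast path on NPU.
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=2,
        use_grouped_topk=False,
        renormalize=True,
        scoring_func="softmax",
        global_num_experts=8)

    assert topk_weights.shape == (4, 2)   # (num_tokens, top_k)
    assert topk_ids.dtype == torch.int32  # required by npu_moe_init_routing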

tests/ut/ops/test_fused_ops.py

Lines changed: 26 additions & 0 deletions

@@ -25,6 +25,7 @@
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
 
 adapt_patch(True)
@@ -389,3 +390,28 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
             assert result.shape == (16, 2)
         else:
             assert result.shape == x.shape
+
+
+class TestExpertsSelector:
+
+    @pytest.mark.parametrize("global_num_experts", [[256], [128]])
+    def test_select_experts(self, mock_dist_env, mock_moe_env,
+                            global_num_experts):
+
+        x = torch.randn(8, 2)
+        router_logits = torch.randn(8, 2)
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=2,
+            use_grouped_topk=False,
+            renormalize=True,
+            topk_group=None,
+            num_expert_group=None,
+            custom_routing_function=None,
+            scoring_func="softmax",
+            e_score_correction_bias=None,
+            global_num_experts=global_num_experts)
+
+        assert topk_weights.shape == (8, 2)
+        assert topk_ids.shape == (8, 2)
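
The new test only asserts output shapes; the renormalize=True flag it passes corresponds to the _renormalize_topk_weights step visible in the code removed from fused_moe.py further down in this commit. A standalone illustration of that step (plain PyTorch, no NPU required):

    import torch

    # After top-k selection, renormalize=True rescales each token's routing
    # weights so they sum to 1 (mirrors _renormalize_topk_weights in the source).
    topk_weights = torch.tensor([[0.6, 0.2],
                                 [0.1, 0.3]])
    renormalized = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    print(renormalized)  # tensor([[0.7500, 0.2500], [0.2500, 0.7500]])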

tests/ut/quantization/test_w8a8.py

Lines changed: 13 additions & 12 deletions

@@ -5,12 +5,13 @@
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
+                                                     select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
                                            fused_experts, fused_experts_310p,
-                                           native_grouped_topk,
-                                           quant_per_tensor, select_experts)
+                                           quant_per_tensor)
 
 
 class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ def test_grouped_topk(self, mock_topk):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
+    @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ def test_basic_group_selection(self):
 
         with patch('torch.topk',
                    return_value=(None, expected_topk_indices)) as mock_topk:
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=2)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=2)
 
         mock_topk.assert_called_once()
 
@@ -885,9 +886,9 @@ def test_partial_group_selection(self):
         expected_topk_indices = torch.tensor([[0], [1]])
 
         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=2,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=2,
+                                          topk_group=1)
 
         expected_result = torch.tensor(
             [[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ def test_single_group(self):
         expected_topk_indices = torch.tensor([[0], [0]])
 
         with patch('torch.topk', return_value=(None, expected_topk_indices)):
-            result = native_grouped_topk(topk_weights=topk_weights,
-                                         num_expert_group=1,
-                                         topk_group=1)
+            result = _native_grouped_topk(topk_weights=topk_weights,
+                                          num_expert_group=1,
+                                          topk_group=1)
         self.assertTrue(result.numel() > 0)
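
Beyond the rename (the leading underscore marks the helper as private to the new module), these tests still assert the same group-masking behavior; compare the expected_result above with the native_grouped_topk body removed from fused_moe.py below. For illustration, a self-contained re-implementation of that masking (the function name here is made up; only the tensor math mirrors the source):

    import torch

    def grouped_topk_mask(topk_weights, num_expert_group, topk_group):
        # Keep the weights belonging to the top `topk_group` expert groups
        # (ranked by each group's max weight) and zero out the rest.
        num_token = topk_weights.shape[0]
        grouped = topk_weights.view(num_token, num_expert_group,
                                    -1).max(dim=-1).values
        group_idx = torch.topk(grouped.float(), k=topk_group, dim=-1,
                               sorted=False)[1]
        group_mask = torch.zeros_like(grouped).scatter_(1, group_idx, 1)
        expert_mask = group_mask.unsqueeze(-1).expand(
            num_token, num_expert_group,
            topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)
        return topk_weights.masked_fill(~expert_mask.bool(), 0.0)

    weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.4, 0.5, 0.6]])
    print(grouped_topk_mask(weights, num_expert_group=2, topk_group=1))
    # tensor([[0.1000, 0.9000, 0.2000, 0.8000, 0.0000, 0.0000, 0.0000, 0.0000]])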

vllm_ascend/ops/common_fused_moe.py

Lines changed: 4 additions & 5 deletions

@@ -23,8 +23,8 @@
     UnquantizedFusedMoEMethod
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
-                                       select_experts)
+from vllm_ascend.ops.fused_moe import fused_experts, fused_experts_moge
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.utils import is_310p
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -58,7 +58,7 @@ def forward_oot(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
-        global_num_experts: Optional[int] = None,
+        global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -68,7 +68,6 @@ def forward_oot(
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     topk_weights, topk_ids = select_experts(
-        global_num_experts=global_num_experts,
         hidden_states=x,
         router_logits=router_logits,
         top_k=top_k,
@@ -79,7 +78,7 @@ def forward_oot(
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
-    )
+        global_num_experts=global_num_experts)
 
    if topk_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None

vllm_ascend/ops/fused_moe.py

Lines changed: 14 additions & 167 deletions

@@ -45,6 +45,7 @@
     data_parallel_reduce_scatter
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -863,143 +864,6 @@ def fused_experts(
     return final_hidden_states
 
 
-def native_grouped_topk(
-    topk_weights: torch.Tensor,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
-):
-    topk_group = 0 if topk_group is None else topk_group
-    num_expert_group = 0 if num_expert_group is None else num_expert_group
-
-    num_token = topk_weights.shape[0]
-    grouped_weights = topk_weights.view(num_token, num_expert_group,
-                                        -1).max(dim=-1).values
-    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
-                                    k=topk_group,
-                                    dim=-1,
-                                    sorted=False)[1]
-    topk_group_mask = torch.zeros_like(grouped_weights)
-    topk_group_mask.scatter_(1, topk_group_indices, 1)
-    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
-    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
-
-    return topk_weights
-
-
-def select_experts(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    global_num_experts: Optional[torch.Tensor] = None
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Select top-k experts based on router logits.
-
-    Args:
-        hidden_states: Hidden states of shape (num_tokens, hidden_size).
-        router_logits: Router logits of shape (num_tokens, num_experts).
-        top_k: Number of experts to select.
-        use_grouped_topk: Whether to group experts before selecting top-k.
-        renormalize: Whether to renormalize the routing weights.
-        topk_group: Number of expert groups to select from.
-        num_expert_group: Number of experts in each group.
-        custom_routing_function: Custom routing function.
-        scoring_func: Scoring function to use.
-        e_score_correction_bias: Correction bias to apply to expert scores.
-
-    Returns:
-        topk_weights: Routing weights of shape (num_tokens, top_k).
-        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
-
-    Raises:
-        ValueError: If an unsupported scoring function is provided.
-    """
-
-    def _renormalize_topk_weights(
-        topk_weights: torch.Tensor,
-        renormalize: bool,
-    ):
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1,
-                                                           keepdim=True)
-        return topk_weights
-
-    if scoring_func == "softmax":
-        # NOTE: vLLM use dtype=torch.float here
-        if not use_grouped_topk and custom_routing_function is None:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-                x=router_logits, finished=None, k=top_k)
-            topk_ids = topk_ids.to(torch.int32)
-            topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-            return topk_weights, topk_ids
-
-        topk_weights = router_logits.softmax(dim=-1)
-    elif scoring_func == "sigmoid":
-        topk_weights = router_logits.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-
-        if e_score_correction_bias is not None:
-            # Store original scores before applying correction bias. We use biased
-            # scores for expert selection but original scores for routing weights
-            original_weights = topk_weights
-            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
-
-        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
-        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
-        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
-                                           topk_group)
-        # TODO bfloat16 is not supported in torch.topk with ge graph.
-        if e_score_correction_bias is not None:
-            topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                  k=top_k,
-                                  dim=-1,
-                                  sorted=False)[1]
-            # Use original unbiased scores for the routing weights
-            topk_weights = original_weights.gather(1, topk_ids)
-        else:
-            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
-                                                k=top_k,
-                                                dim=-1,
-                                                sorted=False)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-        return topk_weights, topk_ids
-
-    if custom_routing_function is not None:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            global_num_experts=global_num_experts)
-        # Required by npu_moe_init_routing
-        topk_ids = topk_ids.to(torch.int32)
-        return topk_weights, topk_ids
-
-    topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
-    topk_weights = topk_weights.to(hidden_states.dtype)
-
-    # Required by npu_moe_init_routing
-    topk_ids = topk_ids.to(torch.int32)
-    topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
-
-    return topk_weights, topk_ids
-
-
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: FusedMoEConfig = None):
@@ -1054,36 +918,19 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        is_deepseek_v3_r1 = global_num_experts == 256
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if is_deepseek_v3_r1:
-            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
-                k=top_k,  # topk currently is 8
-                bias=e_score_correction_bias,
-                k_group=topk_group,  # fix: 4
-                group_count=num_expert_group,  # fix 8
-                group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-                # out_flag=False, # todo new api; should the third output be output
-                # y2_flag=False, # old api; should the third output be output
-                routed_scaling_factor=1,
-                eps=float(1e-20))
-        else:
-            topk_weights, topk_ids = select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                top_k=top_k,
-                use_grouped_topk=use_grouped_topk,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            global_num_experts=global_num_experts,
+            is_unquantized=True)
 
         topk_weights = topk_weights.to(x.dtype)
         # this is a naive implementation for experts load balance so as
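
The DeepSeek-V3/R1 fast path removed here (the npu_moe_gating_top_k branch) is not re-added in this hunk, so the implication is that the relocated select_experts handles it internally, keyed off the global_num_experts and is_unquantized arguments the call now passes. A hypothetical sketch of that dispatch, for orientation only (the kernel arguments are copied verbatim from the code removed above; the wrapper function and its name are assumptions, not the actual experts_selector implementation):

    import torch

    def _maybe_npu_gating_top_k(router_logits, top_k, topk_group,
                                num_expert_group, e_score_correction_bias,
                                global_num_experts, is_unquantized):
        # Hypothetical: only the DeepSeek-V3/R1 shape (256 global experts) is
        # supported by the fused kernel, mirroring the removed branch; return
        # None so the caller falls back to the generic software path otherwise.
        if not (is_unquantized and global_num_experts == 256):
            return None
        import torch_npu  # NPU-only dependency
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,
            bias=e_score_correction_bias,
            k_group=topk_group,            # fixed: 4
            group_count=num_expert_group,  # fixed: 8
            group_select_mode=1,           # 1: top2.sum within each group
            renorm=0,                      # 0: softmax -> topk
            norm_type=1,                   # 1: sigmoid scoring
            routed_scaling_factor=1,
            eps=float(1e-20))
        return topk_weights, topk_ids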

vllm_ascend/ops/layers/__init__.py

Whitespace-only changes.
