Skip to content

Commit 3f11265

Browse files
author
yangcheng (AJ)
committed
refactor
1 parent 4c8842d commit 3f11265

File tree

4 files changed

+111
-65
lines changed

4 files changed

+111
-65
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
data_parallel_reduce_scatter
4646
from vllm_ascend.distributed.parallel_state import get_mc2_group
4747
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
48+
from vllm_ascend.ops.moe_layer.select_experts import UnquantizedSelectExperts
49+
from vllm_ascend.ops.moe_layer.config import SelectExpertConfig
4850
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
4951
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
5052
get_all_reduce_merge_state,
@@ -1034,6 +1036,8 @@ def __init__(self, moe: FusedMoEConfig = None):
10341036
except AttributeError:
10351037
self.moe_all_to_all_group_name = None
10361038

1039+
self.select_experts = UnquantizedSelectExperts()
1040+
10371041
def process_weights_after_loading(self, layer):
10381042
super(UnquantizedFusedMoEMethod,
10391043
self).process_weights_after_loading(layer)
@@ -1065,41 +1069,7 @@ def apply(
10651069
**kwargs,
10661070
) -> torch.Tensor:
10671071

1068-
is_deepseek_v3_r1 = global_num_experts == 256
1069-
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
1070-
if is_deepseek_v3_r1:
1071-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
1072-
router_logits,
1073-
k=top_k, # topk当前写8
1074-
bias=e_score_correction_bias,
1075-
k_group=topk_group, # fix: 4
1076-
group_count=num_expert_group, # fix 8
1077-
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
1078-
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
1079-
norm_type=1, # 0: softmax; 1: sigmoid(fix)
1080-
# out_flag=False, # todo new api; 第三个输出是否输出
1081-
# y2_flag=False, # old api; 第三个输出是否输出
1082-
routed_scaling_factor=1,
1083-
eps=float(1e-20))
1084-
elif SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
1085-
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
1086-
hidden_states=x,
1087-
router_logits=router_logits,
1088-
top_k=top_k,
1089-
renormalize=renormalize)
1090-
else:
1091-
topk_weights, topk_ids = select_experts(
1092-
hidden_states=x,
1093-
router_logits=router_logits,
1094-
top_k=top_k,
1095-
use_grouped_topk=use_grouped_topk,
1096-
renormalize=renormalize,
1097-
topk_group=topk_group,
1098-
num_expert_group=num_expert_group,
1099-
custom_routing_function=custom_routing_function,
1100-
scoring_func=scoring_func,
1101-
e_score_correction_bias=e_score_correction_bias,
1102-
)
1072+
topk_weights, topk_ids = self.select_experts(router_logits, x)
11031073

11041074
topk_weights = topk_weights.to(x.dtype)
11051075
# this is a naive implementation for experts load balance so as
@@ -1268,6 +1238,20 @@ def __init__(
12681238
in_dtype=params_dtype,
12691239
quant_config=quant_config)
12701240

1241+
select_experts_dict = {
1242+
'top_k' : top_k,
1243+
'e_score_correction_bias' : e_score_correction_bias,
1244+
'topk_group' : topk_group,
1245+
'num_expert_group' : num_expert_group,
1246+
'custom_routing_function' : custom_routing_function,
1247+
'scoring_func' : scoring_func,
1248+
'global_num_experts' : self.global_num_experts,
1249+
'use_grouped_topk' : use_grouped_topk,
1250+
'renormalize' : renormalize,
1251+
}
1252+
1253+
SelectExpertConfig(select_experts_dict)
1254+
12711255
if quant_config is None:
12721256
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
12731257
else:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
class SelectExpertConfig:
3+
def __init__(self, config):
4+
self.config = config
5+
6+
@staticmethod
7+
def get_config():
8+
return self.config
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from abc import ABC, abstractmethod
2+
import torch_npu
3+
from vllm_ascend.ops.fused_moe import select_experts
4+
import vllm_ascend.envs as envs_ascend
5+
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
6+
7+
8+
class BaseSelectExperts(ABC):
9+
10+
def __init__(self):
11+
need_param = SelectExpertConfig.get_config
12+
self.top_k = need_param["top_k"]
13+
self.e_score_correction_bias = need_param["e_score_correction_bias"]
14+
self.topk_group = need_param["topk_group"]
15+
self.num_expert_group = need_param["num_expert_group"]
16+
self.custom_routing_function = need_param["custom_routing_function"]
17+
self.scoring_func = need_param["scoring_func"]
18+
self.global_num_experts = need_param["global_num_experts"]
19+
self.use_grouped_topk = need_param['use_grouped_topk']
20+
self.renormalize = need_param['renormalize']
21+
22+
def forward(self, router_logits: torch.Tensor, x: torch.Tensor):
23+
if self.global_num_experts == 256:
24+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
25+
router_logits,
26+
k=self.top_k, # topk当前写8
27+
bias=self.e_score_correction_bias,
28+
k_group=self.topk_group, # fix: 4
29+
group_count=self.num_expert_group, # fix 8
30+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
31+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
32+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
33+
# out_flag=False, # todo new api; 第三个输出是否输出
34+
# y2_flag=False, # old api; 第三个输出是否输出
35+
routed_scaling_factor=1,
36+
eps=float(1e-20))
37+
else:
38+
topk_weights, topk_ids = select_experts(
39+
hidden_states=x,
40+
router_logits=router_logits,
41+
top_k=self.top_k,
42+
use_grouped_topk=self.use_grouped_topk,
43+
renormalize=self.renormalize,
44+
topk_group=self.topk_group,
45+
num_expert_group=self.num_expert_group,
46+
custom_routing_function=self.custom_routing_function,
47+
scoring_func=self.scoring_func,
48+
e_score_correction_bias=self.e_score_correction_bias,
49+
)
50+
return topk_weights, topk_ids
51+
52+
53+
class UnquantizedSelectExperts(BaseSelectExperts):
54+
def __init__(self):
55+
super().__init__()
56+
57+
def forward(self, router_logits: torch.Tensor, x: torch.Tensor):
58+
if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
59+
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
60+
hidden_states=x,
61+
router_logits=router_logits,
62+
top_k=self.top_k,
63+
renormalize=self.renormalize)
64+
else:
65+
topk_weights, topk_ids = super().forward(router_logits, x)
66+
67+
return topk_weights, topk_ids
68+
69+
70+
class QuantizedSelectExperts(BaseSelectExperts):
71+
def __init__(self):
72+
super().__init__()
73+
74+
def forward(self, router_logits: torch.Tensor, x: torch.Tensor):
75+
76+
return super().forward(router_logits, x)
77+
78+
79+

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from vllm_ascend.ascend_config import get_ascend_config
2828
from vllm_ascend.ascend_forward_context import FusedMoEState
2929
from vllm_ascend.distributed.parallel_state import get_mc2_group
30+
from vllm_ascend.ops.moe_layer.select_experts import QuantizedSelectExperts
3031
from vllm_ascend.ops.fused_moe import select_experts
3132
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
3233
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
@@ -766,6 +767,8 @@ def __init__(self):
766767
except AttributeError:
767768
self.moe_all_to_all_group_name = ""
768769

770+
self.select_experts = QuantizedSelectExperts()
771+
769772
@staticmethod
770773
def get_weight(num_experts: int, intermediate_size_per_partition: int,
771774
hidden_sizes: int,
@@ -835,36 +838,8 @@ def apply(
835838
assert router_logits.shape[
836839
1] == global_num_experts, "Number of global experts mismatch"
837840

838-
is_deepseek_v3_r1 = global_num_experts == 256
839-
840-
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
841-
if is_deepseek_v3_r1:
842-
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
843-
router_logits,
844-
k=top_k, # topk当前写8
845-
bias=e_score_correction_bias,
846-
k_group=topk_group, # fix: 4
847-
group_count=num_expert_group, # fix 8
848-
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
849-
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
850-
norm_type=1, # 0: softmax; 1: sigmoid(fix)
851-
# out_flag=False, # todo new api; 第三个输出是否输出
852-
# y2_flag=False, # old api; 第三个输出是否输出
853-
routed_scaling_factor=1,
854-
eps=float(1e-20))
855-
else:
856-
topk_weights, topk_ids = select_experts(
857-
hidden_states=x,
858-
router_logits=router_logits,
859-
top_k=top_k,
860-
use_grouped_topk=use_grouped_topk,
861-
renormalize=renormalize,
862-
topk_group=topk_group,
863-
num_expert_group=num_expert_group,
864-
custom_routing_function=custom_routing_function,
865-
scoring_func=scoring_func,
866-
e_score_correction_bias=e_score_correction_bias,
867-
)
841+
842+
topk_weights, topk_ids = self.select_experts(router_logits, x)
868843

869844
fused_moe_state = get_forward_context().fused_moe_state
870845
shared_gate_up, shared_dequant_scale = None, None

0 commit comments

Comments
 (0)