4949from vllm .model_executor .sampling_metadata import SamplingMetadata
5050from vllm .sequence import IntermediateTensors
5151
52- from vllm_ascend .distributed . parallel_state import get_ep_group
52+ from vllm_ascend .utils import is_310p
5353
5454logger = init_logger (__name__ )
5555
@@ -95,41 +95,69 @@ def forward(self, x):
9595 return x
9696
9797
98- class PanguProMoESparseMoeBlock ( nn . Module ):
98+ def topk_wrapper ( num_voted_experts ):
9999
100- @staticmethod
101100 def pangu_group8_topk (
102101 hidden_states : torch .Tensor ,
103102 gating_output : torch .Tensor ,
104103 topk : int ,
105- renormalize : bool ,
104+ renormalize : bool = False ,
106105 num_expert_group : int = 0 ,
107106 topk_group : int = 0 ,
108107 global_num_experts : int = 0 ,
109108 ):
110- ep_size = get_ep_group ().world_size
111- local_num_experts = global_num_experts // ep_size
112- local_num_group = topk // ep_size
113- router_scale = _ROUTER_SCALE .squeeze () # type: ignore
114- scores = F .softmax (gating_output , dim = 1 )
115- scores = scores [...,
116- get_ep_group ().rank_in_group *
117- local_num_experts :(get_ep_group ().rank_in_group + 1 ) *
118- local_num_experts ]
119-
120- router_weights = router_scale [get_ep_group ().rank_in_group *
121- local_num_experts :
122- (get_ep_group ().rank_in_group + 1 ) *
123- local_num_experts ]
124- topk_weights , topk_ids = torch .max (scores .view (scores .shape [0 ],
125- local_num_group , - 1 ),
126- dim = - 1 )
127- bias = torch .arange (0 ,
128- local_num_experts ,
129- topk ,
130- device = scores .device ,
131- dtype = torch .int32 ).unsqueeze (0 )
132- topk_ids = topk_ids .to (torch .int32 ) + bias
109+ scores = F .softmax (gating_output , dim = 1 , dtype = torch .float16 )
110+ num_tokens = scores .shape [0 ]
111+ router_weights = _ROUTER_SCALE .squeeze ( # type: ignore
112+ ).to (torch .float16 )
113+
114+ if num_voted_experts == 8 :
115+ # use original topk
116+ topk_weights , topk_ids = torch .max (scores .view (
117+ scores .shape [0 ], topk , - 1 ),
118+ dim = - 1 )
119+ bias = torch .arange (0 ,
120+ global_num_experts ,
121+ topk ,
122+ device = scores .device ,
123+ dtype = torch .int32 ).unsqueeze (0 )
124+ topk_ids = topk_ids .to (torch .int32 ) + bias
125+
126+ else :
127+ experts_per_group = global_num_experts // topk
128+ group_expert_indices = torch .arange (experts_per_group ,
129+ dtype = torch .int32 ,
130+ device = scores .device ).view (
131+ 1 , 1 , - 1 )
132+ group_expert_offset = (
133+ torch .arange (topk , dtype = torch .int32 , device = scores .device ) *
134+ experts_per_group ).unsqueeze (0 )
135+ expert_index_range = torch .arange (experts_per_group ,
136+ dtype = torch .int32 ,
137+ device = scores .device )
138+
139+ scores_grouped = scores .view (num_tokens , topk , experts_per_group )
140+ best_expert_idx = torch .argmax (scores_grouped ,
141+ dim = 2 ) # (num_tokens, num_groups)
142+ vote_mask = (best_expert_idx .unsqueeze (- 1 ).to (
143+ torch .int32 ) == group_expert_indices ).to (torch .float16 )
144+
145+ expert_vote_freq = vote_mask .sum (dim = 0 )
146+
147+ sorted_indices = torch .argsort (expert_vote_freq ,
148+ dim = 1 ,
149+ descending = True ).to (torch .int32 )
150+ topk_experts = sorted_indices [:, :num_voted_experts ]
151+ keep_mask = ((
152+ topk_experts .unsqueeze (- 1 ) == expert_index_range ).any (
153+ dim = 1 )).unsqueeze (0 )
154+
155+ masked_scores = torch .where (keep_mask , scores_grouped , 0 )
156+
157+ topk_weights , best_pos_in_group = masked_scores .max (dim = 2 )
158+ best_pos_in_group = best_pos_in_group .to (torch .int32 )
159+ topk_ids = (best_pos_in_group + group_expert_offset ).to (
160+ torch .int32 )
133161
134162 flatten_topk_ids = topk_ids .view (- 1 )
135163 router_weights = router_weights .index_select (0 , flatten_topk_ids ).view (
@@ -138,6 +166,11 @@ def pangu_group8_topk(
138166
139167 return topk_weights , topk_ids
140168
169+ return pangu_group8_topk
170+
171+
172+ class PanguProMoESparseMoeBlock (nn .Module ):
173+
141174 def __init__ (
142175 self ,
143176 config : PretrainedConfig ,
@@ -153,23 +186,23 @@ def __init__(
153186 f"Tensor parallel size { self .tp_size } is greater than "
154187 f"the number of experts { config .num_experts } ." )
155188
156- self .local_num_group = config .num_experts_per_tok // get_ep_group (
157- ).world_size
158189 self .num_experts_per_tok = config .num_experts_per_tok
159- self .local_num_experts = config .num_experts // get_ep_group (
160- ).world_size
161190 self .router_scale = torch .nn .Parameter (
162191 torch .ones ((1 , self .num_experts )))
163192
193+ # On the 300I Duo platform, we find that setting num_voted_experts to 5
194+ # achieves good performance without sacrificing too much accuracy. On other
195+ # platforms, it is set to 8 to use the original Pangu grouped top-k.
196+ num_voted_experts = 5 if is_310p () else 8
197+
164198 self .experts = FusedMoE (
165199 num_experts = config .num_experts ,
166200 top_k = config .num_experts_per_tok ,
167201 hidden_size = config .hidden_size ,
168202 intermediate_size = config .moe_intermediate_size ,
169203 reduce_results = False ,
170204 quant_config = quant_config ,
171- custom_routing_function = PanguProMoESparseMoeBlock .
172- pangu_group8_topk ,
205+ custom_routing_function = topk_wrapper (num_voted_experts ),
173206 prefix = f"{ prefix } .experts" ,
174207 )
175208
0 commit comments