@@ -330,37 +330,5 @@ def forward_oot(
330330 top_k = top_k ,
331331 expert_map = expert_map )
332332
def forward(self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor, top_k=None):
    """Route tokens through the fused MoE experts and combine the results.

    Intended to be monkey-patched onto ``FusedMoE`` as its ``forward``.

    Args:
        hidden_states: Input activations for the MoE layer.
        router_logits: Per-token routing scores over the experts.
        top_k: Optional per-call override for the number of experts
            selected per token; falls back to ``self.top_k`` when omitted.

    Returns:
        The expert-combined hidden states; all-reduced across the tensor
        parallel group when ``self.reduce_results`` is set and
        ``self.tp_size > 1``.
    """
    assert self.quant_method is not None

    # Use an explicit None check (not truthiness) so a caller-supplied
    # top_k=0 is not silently replaced by self.top_k.
    real_top_k = self.top_k if top_k is None else top_k

    # Expert selection + weighted combine are delegated to the quant
    # method implementation (matrix multiply happens inside apply()).
    final_hidden_states = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=real_top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias)

    # Combine partial expert outputs across tensor-parallel ranks when
    # this layer is configured to reduce its own output.
    if self.reduce_results and self.tp_size > 1:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

    return final_hidden_states
# Monkey-patch the upstream vLLM MoE classes so they dispatch to the
# out-of-tree implementations defined in this module.
UnquantizedFusedMoEMethod.forward_oot = forward_oot
FusedMoE.forward = forward
0 commit comments