@@ -275,9 +275,9 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts)

-        fused_moe_state = get_forward_context().moe_comm_method_name
+        fused_moe_state = get_forward_context().fused_moe_state
         shared_gate_up, shared_dequant_scale = None, None
-        if shared_experts is not None and fused_moe_state == "mc2commimpl":
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
             share_up_out, _ = shared_experts.gate_up_proj(
                 (quantized_x_for_share, dynamic_scale_for_share))
             shared_gate_up, shared_dequant_scale = share_up_out[
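Why the enum comparison matters: the old branch keyed off a lowercased comm-method class name ("mc2commimpl"), which goes stale silently if that class is ever renamed, whereas a misspelled enum member fails loudly at import time. A minimal runnable sketch, assuming FusedMoEState is an Enum with an MC2 member (the other member shown is illustrative):

    from enum import Enum

    class FusedMoEState(Enum):
        AllGather = 0  # illustrative member, not confirmed by this diff
        MC2 = 1        # the mode checked in the new condition above

    def overlap_shared_experts(shared_experts, fused_moe_state):
        # Enum identity survives class renames; the old string check
        # compared against a lowercased class name and could go stale.
        return shared_experts is not None and fused_moe_state == FusedMoEState.MC2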
@@ -290,10 +290,8 @@ def apply(
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

         topk_weights = topk_weights.to(x.dtype)
-
-        moe_comm_method = get_forward_context().moe_comm_method
-        print("using w4a8")
-        return moe_comm_method.fused_experts(
+
+        return unified_fused_experts_eager(
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
@@ -304,34 +304,14 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             row_idx=row_idx,
-            use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
-            shared_dequant_scale=shared_dequant_scale
-        )
-
-        # return unified_fused_experts_eager(
-        #     hidden_states=x,
-        #     w1=layer.w13_weight,
-        #     w2=layer.w2_weight,
-        #     w1_scale=layer.w13_weight_scale_second,
-        #     w2_scale=layer.w2_weight_scale_second,
-        #     w1_scale_bias=layer.w13_scale_bias,
-        #     w2_scale_bias=layer.w2_scale_bias,
-        #     topk_weights=topk_weights,
-        #     topk_ids=topk_ids,
-        #     row_idx=row_idx,
-        #     expert_map=expert_map,
-        #     log2phy=log2phy,
-        #     global_redundant_expert_num=global_redundant_expert_num,
-        #     shared_experts=shared_experts,
-        #     shared_gate_up=shared_gate_up,
-        #     shared_dequant_scale=shared_dequant_scale,
-        #     mc2_mask=kwargs.get("mc2_mask", None),
-        #     with_quant=True)
+            shared_dequant_scale=shared_dequant_scale,
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)

     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape
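The previously dead (commented-out) unified_fused_experts_eager call is now the live return path, and it restores two kwargs the interim moe_comm_method path dropped: mc2_mask and with_quant. A toy, runnable stand-in (not the real vllm-ascend kernel; the tensor math below is placeholder) showing how those two kwargs gate behavior:

    from typing import Optional

    import torch

    def unified_fused_experts_eager_sketch(
            hidden_states: torch.Tensor,
            topk_weights: torch.Tensor,
            with_quant: bool = False,
            mc2_mask: Optional[torch.Tensor] = None,
            **kwargs) -> torch.Tensor:
        # Placeholder for the fused expert GEMMs: scale each token by the
        # sum of its routing weights.
        out = hidden_states * topk_weights.sum(dim=-1, keepdim=True)
        if with_quant:
            # The quantized path would consume w1_scale / w2_scale /
            # w1_scale_bias / w2_scale_bias from kwargs to dequantize the
            # int4-weight matmuls; elided in this sketch.
            pass
        if mc2_mask is not None:
            # MC2 dispatch masks out tokens not handled by this rank.
            out = out * mc2_mask.to(out.dtype).unsqueeze(-1)
        return out

    # Usage, mirroring the call site: mc2_mask arrives via kwargs.get(...).
    tokens = torch.randn(4, 16)
    weights = torch.rand(4, 2)
    mask = torch.tensor([1, 0, 1, 1])
    y = unified_fused_experts_eager_sketch(
        tokens, weights, with_quant=True, mc2_mask=mask)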