     dispose_tensor, get_fused_moe_state)
 
 
+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
42+ """
43+ apply MLP: gate_up_proj -> swiglu -> down_proj
44+ Args:
45+ hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
46+ w1: expert weights1 with shape
47+ (num_experts, hidden_size, intermediate_size * 2)
48+ w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
49+ w2: expert weights2 with shape
50+ (num_experts, intermediate_size, hidden_size)
51+ w2_scale: weights2 scale with shape (num_experts, hidden_size)
52+ group_list: number of tokens for each expert, follow cumsum mode, and
53+ with shape (num_experts).
54+ transpose_weight:
55+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
56+ (num_experts, hidden_size, intermediate_size * 2)
57+ w2: (num_experts, hidden_size, intermediate_size) ->
58+ (num_experts, intermediate_size, hidden_size)
59+ Returns:
60+ hidden_states: output hidden states after MLP.
61+ """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
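+    # gmm1 keeps its int32 accumulations; npu_dequant_swiglu_quant then
+    # dequantizes them with the weight and per-token scales, applies swiglu
+    # and re-quantizes per token, and gmm2 consumes the quantized result
+    # together with swiglu_out_scale.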
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
@@ -118,6 +194,8 @@ def fused_experts_with_mc2(
     log2phy: torch.Tensor = None,
     global_redundant_expert_num: int = 0,
     shared_experts: Optional[Any] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy is not None:
         topk_ids = log2phy[topk_ids]
@@ -165,19 +243,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up[0], expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
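+            # The shared experts' gate_up output is now computed upstream in
+            # apply(); here we only wait for the MC2 dispatch output before
+            # running the fused activation on the secondary stream.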
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act_out = shared_experts.act_fn(
+                (shared_gate_up, shared_dequant_scale))
+            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
     # `expand_x` will be disposed in the `apply_mlp` function
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
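+    # When dynamic_scale is provided, apply_mlp_decode reuses it as the
+    # per-token scale and skips the npu_dynamic_quant step.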
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
@@ -213,8 +291,9 @@ def fused_experts_with_mc2(
         return hidden_states
     else:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act[0], down_out_list)
-            shared_output, _ = shared_experts.down_proj(shared_act)
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
         return hidden_states, shared_output
 
 
@@ -708,6 +787,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -744,16 +825,24 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
+                                              is_prefill, is_deepseek_v3_r1)
+        shared_gate_up, shared_dequant_scale = None, None
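+        # In the MC2 state, pre-compute the shared experts' gate_up projection
+        # on the secondary stream from the already-quantized input so that it
+        # overlaps with the routing and dispatch work on the main stream.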
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # This is a naive implementation of expert load balancing, meant to
         # avoid accumulating too many tokens on a single rank.
         # Currently it is only activated during profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
-                                              is_prefill, is_deepseek_v3_r1)
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -770,7 +859,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -779,7 +868,9 @@ def apply(
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -822,6 +913,8 @@ def process_weights_after_loading(self, layer):
         torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
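+        # Keep an fp32 copy of the w13 scale; the AllGatherEP fused-experts
+        # path consumes it as w1_scale.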
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(