# limitations under the License.
#

-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
    npu_stream_switch, npu_wait_tensor)


+def apply_mlp_decode(hidden_states_wrapper: List[torch.Tensor],
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
42+ """
43+ apply MLP: gate_up_proj -> swiglu -> down_proj
44+ Args:
45+ hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
46+ w1: expert weights1 with shape
47+ (num_experts, hidden_size, intermediate_size * 2)
48+ w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
49+ w2: expert weights2 with shape
50+ (num_experts, intermediate_size, hidden_size)
51+ w2_scale: weights2 scale with shape (num_experts, hidden_size)
52+ group_list: number of tokens for each expert, follow cumsum mode, and
53+ with shape (num_experts).
54+ transpose_weight:
55+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
56+ (num_experts, hidden_size, intermediate_size * 2)
57+ w2: (num_experts, hidden_size, intermediate_size) ->
58+ (num_experts, intermediate_size, hidden_size)
59+ Returns:
60+ hidden_states: output hidden states after MLP.
61+ """
+
+    assert len(hidden_states_wrapper) == 1
+    hidden_states = hidden_states_wrapper.pop()
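+    # the wrapper is expected to hold exactly one tensor; popping it lets this
+    # function drop that reference early (cf. the `expand_x` note at the call site)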
+    if dynamic_scale is None:
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
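+    # int8 grouped matmul accumulated in int32; dequantization with w1_scale and
+    # the per-token activation scale is fused into the swiglu step below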
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
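+    # fused op: dequantize the gmm1 output, apply swiglu to the left half
+    # (activate_left=True), then re-quantize per token (quant_mode=1), returning
+    # the activation together with its per-token output scale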
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
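+    # w2_scale and the swiglu per-token scale are applied inside the grouped
+    # matmul, so the result comes out dequantized in w2_scale.dtype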
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
def apply_mlp(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
@@ -118,6 +192,8 @@ def fused_experts_with_mc2(
    log2phy: torch.Tensor = None,
    global_redundant_expert_num: int = 0,
    shared_experts: Optional[Any] = None,
+    quantized_x_for_share: Optional[Any] = None,
+    dynamic_scale_for_share: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]
@@ -165,19 +241,19 @@ def fused_experts_with_mc2(

    if shared_experts is not None:
        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
-            npu_wait_tensor(shared_gate_up[0], expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
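+            # the shared experts' gate_up_proj now runs earlier in `apply`;
+            # quantized_x_for_share / dynamic_scale_for_share carry its quantized
+            # output and dequant scale, so only the activation is done here on
+            # the secondary stream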
+            npu_wait_tensor(quantized_x_for_share, expand_x)
+            shared_act_out = shared_experts.act_fn(
+                (quantized_x_for_share, dynamic_scale_for_share))
+            shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

    # `expand_x` will be disposed in the `apply_mlp` function
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    down_out_list = apply_mlp_decode([expand_x],
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)

    # moeCombine
    kwargs_mc2 = {
@@ -213,8 +289,9 @@ def fused_experts_with_mc2(
        return hidden_states
    else:
        with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(shared_act[0], down_out_list)
-            shared_output, _ = shared_experts.down_proj(shared_act)
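+            # down_proj now consumes the quantized swiglu output together with
+            # its per-token scale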
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_output, _ = shared_experts.down_proj(
+                (shared_act, swiglu_out_scale))
        return hidden_states, shared_output


@@ -708,6 +785,8 @@ def apply(
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
@@ -744,6 +823,15 @@ def apply(
            e_score_correction_bias=e_score_correction_bias,
        )

+        shared_gate_up, shared_dequant_scale = None, None
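+        # for MC2, run the shared experts' gate_up_proj early on the secondary
+        # stream; its output and dequant scale are forwarded to
+        # fused_experts_with_mc2 via quantized_x_for_share / dynamic_scale_for_share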
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too many tokens on a single rank.
        # currently it is only activated when doing profile runs.
@@ -770,7 +858,7 @@ def apply(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
@@ -779,7 +867,9 @@ def apply(
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                quantized_x_for_share=shared_gate_up,
+                dynamic_scale_for_share=shared_dequant_scale)
        elif fused_moe_state in [
                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
        ]:
@@ -822,6 +912,8 @@ def process_weights_after_loading(self, layer):
        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
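+        # keep a float32 copy of the w13 scale for the MC2 decode path
+        # (assumption: the fused dequant/swiglu op expects an fp32 weight scale)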
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(