                                dispose_tensor, get_ascend_soc_version)


+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: torch.Tensor = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+    Args:
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2)
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size)
+        w2_scale: weights2 scale with shape (num_experts, hidden_size)
+        group_list: number of tokens for each expert, following cumsum mode,
+            with shape (num_experts).
+        transpose_weight: expected weight layout after transposition:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose the original unquantized hidden states
+        # to save npu memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
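
For orientation, a minimal call sketch for the new apply_mlp_decode, not part of the diff: shapes follow the docstring above, the tensor names are illustrative, and the group_list bookkeeping is produced by the caller exactly as fused_experts_with_mc2 does below.

    # hidden_states: (num_tokens, hidden_size), rows already grouped by expert
    # w1: (num_experts, hidden_size, intermediate_size * 2), quantized weights
    # w2: (num_experts, intermediate_size, hidden_size), quantized weights
    out = apply_mlp_decode(hidden_states,
                           w1, w1_scale,
                           w2, w2_scale,
                           group_list,
                           dynamic_scale=None)  # None: quantize inside the call
    # With dynamic_scale=None the unquantized input is dynamically quantized
    # via torch_npu.npu_dynamic_quant and the original tensor is disposed to
    # free NPU memory; pass a precomputed per-token scale to skip that step.
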
@@ -124,6 +200,8 @@ def fused_experts_with_mc2(
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -186,18 +264,19 @@ def fused_experts_with_mc2(
 
     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(quantized_x_for_share, expand_x)
+            npu_wait_tensor(shared_gate_up, expand_x)
             shared_act_out = shared_experts.act_fn(
-                (quantized_x_for_share, dynamic_scale_for_share))
+                (shared_gate_up, shared_dequant_scale))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
 
-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    # `expand_x` will be disposed in the `apply_mlp_decode` function
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)
 
     # moeCombine
     kwargs_mc2 = {
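
The npu_stream_switch("moe_secondary", 0) / npu_wait_tensor pairing above (and its twin in apply() further down) is the compute-overlap idiom this file relies on: work issued on the secondary stream first blocks on a tensor produced by the main stream, then runs concurrently with the routed-expert path. A hedged sketch of the idiom, with placeholder tensors and a placeholder op (only the two helper names are taken from this file):

    with npu_stream_switch("moe_secondary", 0):
        # stall the secondary stream until `trigger` is produced on the
        # main stream, then overlap the shared-expert work with whatever
        # the main stream runs next (here: apply_mlp_decode).
        npu_wait_tensor(secondary_input, trigger)
        secondary_out = shared_expert_op(secondary_input)  # placeholder op
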
@@ -745,6 +824,8 @@ def apply(
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -781,15 +862,23 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # this is a naive implementation for expert load balancing so as
         # to avoid accumulating too many tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
         topk_weights = topk_weights.to(x.dtype)
-
-        fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -806,7 +895,7 @@ def apply(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -817,7 +906,9 @@ def apply(
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None))
+                mc2_mask=kwargs.get("mc2_mask", None),
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -860,6 +951,8 @@ def process_weights_after_loading(self, layer):
             torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
+        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+            torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
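
The cached layer.w13_weight_scale_fp32 added here pairs with the w1_scale=layer.w13_weight_scale_fp32 change in the AllGatherEP branch above: the dtype conversion is paid once at weight-loading time instead of on every forward call. A minimal, self-contained illustration of that trade-off (plain torch; the names are illustrative and not the layer's real attributes):

    import torch

    scale = torch.rand(8, 4096, dtype=torch.bfloat16)  # per-expert weight scale

    # per-call cast: converts the scale on every forward invocation
    def forward_per_call_cast(x: torch.Tensor) -> torch.Tensor:
        return x * scale.to(torch.float32)

    # one-time cast cached at load time, mirroring the diff above
    scale_fp32 = scale.to(torch.float32)

    def forward_cached(x: torch.Tensor) -> torch.Tensor:
        return x * scale_fp32
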