from typing import Any, Callable, Optional

import torch
+import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, ACL_FORMAT_FRACTAL_NZ

original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -52,7 +53,6 @@ def fused_experts(
    w2_scale: Optional[torch.Tensor] = None,
    w1_scale_bias: torch.Tensor = None,
    w2_scale_bias: torch.Tensor = None,
-    moe_comm_method: Optional[MoECommMethod] = None,
    # For TorchAir graph
    is_torchair: bool = False,
    # For Cube/Vector parallel
@@ -64,8 +64,8 @@ def fused_experts(
    global_redundant_expert_num: int = 0,
) -> torch.Tensor:
    # Check constraints
-    assert hidden_states.shape[1] == w1.shape[2], (
-        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
+    assert hidden_states.shape[1] == w1.shape[1], (
+        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -74,20 +74,58 @@ def fused_experts(
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
+
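+    # The MoE communication method is now read from the forward context
+    # instead of being passed in as an argument.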
+    moe_comm_method = get_forward_context().moe_comm_method
    assert moe_comm_method is not None, "Missing communication context"

    num_experts = w1.shape[0]

-    permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
-        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
-    mlp_output = apply_mlp(
-        permuted_hidden_states,
-        w1,
-        w2,
-        expert_tokens,
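+    # permute() now also returns per-token dynamic scales and is told whether
+    # the experts run quantized (int8 W8A8 / int4 W4A8) kernels.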
+    permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
+        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
+        use_int8_w8a8 or use_int4_w4a8)
+
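+    # Quantized path: if permute() did not already quantize the activations,
+    # do a per-token dynamic quantization here.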
+    if (use_int8_w8a8 or use_int4_w4a8) and dynamic_scale is None:
+        permuted_hidden_states, dynamic_scale = torch_npu.npu_dynamic_quant(
+            permuted_hidden_states)
+
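+    # First grouped GEMM: gate/up projection (w1) for all experts in a single
+    # call; expert_tokens gives the per-expert token counts via group_list.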
+    gate_up_output = torch_npu.npu_grouped_matmul(
+        x=[permuted_hidden_states],
+        weight=[w1],
+        split_item=2,
        group_list_type=group_list_type,
-    )
-    moe_comm_method.unpermute(mlp_output, hidden_states)
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=torch.int32 if use_int8_w8a8 else None,
+    )[0]
+
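+    # int8 path: fused dequantize + SwiGLU + per-token requantization of the
+    # gate/up output; floating-point path: plain SwiGLU.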
+    if use_int8_w8a8:
+        activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=gate_up_output,
+            weight_scale=w1_scale.to(torch.float32),
+            activation_scale=dynamic_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=expert_tokens,
+            activate_left=True,
+            quant_mode=1,
+        )
+    else:
+        activated_output = torch_npu.npu_swiglu(gate_up_output)
+        activated_output_scale = None
+
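+    # Second grouped GEMM: down projection (w2). On the int8 path the weight
+    # scale and per-token activation scale dequantize the output to
+    # w2_scale.dtype.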
+    down_output = torch_npu.npu_grouped_matmul(
+        x=[activated_output],
+        weight=[w2],
+        scale=[w2_scale] if use_int8_w8a8 else None,
+        per_token_scale=[activated_output_scale] if use_int8_w8a8 else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=w2_scale.dtype if use_int8_w8a8 else None,
+    )[0]
+
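+    # Scatter the expert outputs back into hidden_states (updated in place).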
+    moe_comm_method.unpermute(down_output, hidden_states)

    return hidden_states

@@ -156,8 +194,6 @@ def forward_oot(
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input)

-    moe_comm_method = get_forward_context().moe_comm_method
-
    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
@@ -166,10 +202,26 @@ def forward_oot(
        topk_ids=topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
-        moe_comm_method=moe_comm_method,
    )


+def process_weights_after_loading(self, layer):
+    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
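+    # Store expert weights with the last two dims swapped so hidden_size sits
+    # on dim 1, which is the layout the grouped-matmul path (and its
+    # hidden-size assert) expects.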
+    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+        1, 2).contiguous()
+    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+        1, 2).contiguous()
+    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+
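+    # On non-310P devices, cast the weights to the FRACTAL_NZ on-device format.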
+    if not is_310p():
+        layer.w13_weight.data = torch_npu.npu_format_cast(
+            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.w2_weight.data = torch_npu.npu_format_cast(
+            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+
+
class AscendFusedMoE(FusedMoE):

    def __init__(
@@ -281,4 +333,5 @@ def forward_impl(self, hidden_states: torch.Tensor,


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
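+# Patch weight post-processing so unquantized MoE weights are transposed (and
+# NZ-cast) right after loading.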
+UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot