@@ -19,20 +19,20 @@
 
 import torch
 import torch_npu
-from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod)
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     MC2CommImpl,
-                                                     MoECommMethod)
+                                                     AlltoAllCommImpl,
+                                                     MC2CommImpl)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
+from vllm_ascend.ops.fused_moe import fused_experts_moge
 from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.utils import is_310p, ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+    setup_token_dispatchers
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -66,26 +66,32 @@ def fused_experts(
     # Check constraints
     assert hidden_states.shape[1] == w1.shape[1], (
         f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
-
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
     assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
+    if use_int8_w8a8 or use_int4_w4a8:
+        assert w1_scale is not None and w2_scale is not None, \
+            "INT8/INT4 quantization requires weight scales."
+
+        w1_scale = w1_scale.to(torch.float32)
+        down_scale = [w2_scale]
+        down_output_dtype = w2_scale.dtype
+    else:
+        down_scale = None
+        down_output_dtype = None
 
     moe_comm_method = get_forward_context().moe_comm_method
     assert moe_comm_method is not None, "Missing communication context"
 
     num_experts = w1.shape[0]
 
     permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
-        hidden_states, topk_ids, topk_weights, expert_map, num_experts, use_int8_w8a8 or use_int4_w4a8)
-
-    if (use_int8_w8a8 or use_int4_w4a8) and dynamic_scale is None:
-        permuted_hidden_states, dynamic_scale = torch_npu.npu_dynamic_quant(
-            permuted_hidden_states)
+        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
+        use_int8_w8a8 or use_int4_w4a8)
 
     gate_up_output = torch_npu.npu_grouped_matmul(
         x=[permuted_hidden_states],
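
Review note: the new `down_scale` / `down_output_dtype` block hoists the quantization-dependent arguments of the down-projection matmul out of the call site, so the call further down can pass them unconditionally. A minimal sketch of that hoisting pattern, with a hypothetical helper and a stand-in tensor rather than the real per-expert scale:

```python
import torch

# Illustrative helper: resolve the branch-dependent matmul arguments
# once, before any call sites, instead of branching at each keyword.
def make_down_matmul_args(w2_scale, quantized):
    if quantized:
        assert w2_scale is not None, "quantized path needs a weight scale"
        return [w2_scale], w2_scale.dtype  # scale list, output dtype
    return None, None  # float path falls back to kernel defaults

down_scale, down_output_dtype = make_down_matmul_args(
    torch.ones(128, dtype=torch.bfloat16), quantized=True)
assert down_output_dtype == torch.bfloat16
```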
@@ -97,10 +103,10 @@ def fused_experts(
         output_dtype=torch.int32 if use_int8_w8a8 else None,
     )[0]
 
-    if use_int8_w8a8:
+    if use_int8_w8a8 or use_int4_w4a8:
         activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
             x=gate_up_output,
-            weight_scale=w1_scale.to(torch.float32),
+            weight_scale=w1_scale,
             activation_scale=dynamic_scale,
             bias=None,
             quant_scale=None,
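
For reference, `npu_dequant_swiglu_quant` fuses dequantization, the SwiGLU activation, and dynamic requantization into a single NPU kernel. A plain-PyTorch sketch of just the activation part, assuming `activate_left=True` means SiLU is applied to the first (gate) half of the projection; that reading of the flag is my assumption, not something documented in this diff:

```python
import torch
import torch.nn.functional as F

def swiglu_left(gate_up: torch.Tensor) -> torch.Tensor:
    # activate_left=True (assumed): gate half first, then the linear half.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 16)       # [tokens, 2 * intermediate_size]
print(swiglu_left(x).shape)  # torch.Size([4, 8])
```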
@@ -109,42 +115,28 @@ def fused_experts(
             activate_left=True,
             quant_mode=1,
         )
+        activated_output_scale = [activated_output_scale]
     else:
         activated_output = torch_npu.npu_swiglu(gate_up_output)
         activated_output_scale = None
 
     down_output = torch_npu.npu_grouped_matmul(
         x=[activated_output],
         weight=[w2],
-        scale=[w2_scale] if use_int8_w8a8 else None,
-        per_token_scale=[activated_output_scale] if use_int8_w8a8 else None,
+        scale=down_scale,
+        per_token_scale=activated_output_scale,
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
         group_list=expert_tokens,
-        output_dtype=w2_scale.dtype if use_int8_w8a8 else None,
+        output_dtype=down_output_dtype,
     )[0]
 
     moe_comm_method.unpermute(down_output, hidden_states)
 
     return hidden_states
 
 
-def unquantized_fused_moe_init_func(self, *args, **kwargs):
-    original_unquantized_fused_moe_init_func(self, *args, **kwargs)
-    vllm_config = get_current_vllm_config()
-    self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-
-    ascend_config = get_ascend_config()
-
-    if ascend_config.torchair_graph_config.enabled:
-        self.use_aclgraph = False
-    else:
-        self.use_aclgraph = (vllm_config.compilation_config.level
-                             == CompilationLevel.PIECEWISE
-                             and not vllm_config.model_config.enforce_eager)
-
-
 def forward_oot(
     self,
     layer: torch.nn.Module,
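
Both matmuls in this hunk go through `npu_grouped_matmul`, which runs one matmul per expert over contiguous token groups described by `group_list`. A plain-PyTorch emulation of that behavior, assuming per-expert token counts (the real kernel's `group_list_type` encodings, scales, and dtypes are not modeled here):

```python
import torch

def grouped_matmul_ref(x, w, group_sizes):
    # x: [total_tokens, hidden]; w: [num_experts, hidden, out].
    # Expert e multiplies its contiguous slice of group_sizes[e] tokens.
    outs, start = [], 0
    for e, n in enumerate(group_sizes):
        outs.append(x[start:start + n] @ w[e])
        start += n
    return torch.cat(outs, dim=0)

x = torch.randn(6, 4)
w = torch.randn(3, 4, 8)                          # 3 experts
print(grouped_matmul_ref(x, w, [2, 3, 1]).shape)  # torch.Size([6, 8])
```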
@@ -276,12 +268,19 @@ def __init__(
             has_bias,
         )
 
+        with_quant = quant_config is not None
+        setup_token_dispatchers(self.moe_config.ep_size,
+                                top_k=self.top_k,
+                                num_experts=self.global_num_experts,
+                                num_local_experts=self.local_num_experts,
+                                with_quant=with_quant)
+
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
 
-        for method in {AllGatherCommImpl, MC2CommImpl}:
+        for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
             setattr(
                 self, method.__name__.lower(),
                 method(moe_config=self.moe_config))  # type: ignore[abstract]
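
Review note: this loop instantiates every comm implementation and attaches it under its lowercased class name, so the forward path can later select one by string. A self-contained sketch of the idiom, using stand-in classes rather than the real `vllm_ascend` implementations:

```python
class AllGatherCommImpl:
    def __init__(self, moe_config):
        self.moe_config = moe_config

class AlltoAllCommImpl(AllGatherCommImpl):
    pass

class MC2CommImpl(AllGatherCommImpl):
    pass

class Layer:
    def __init__(self, moe_config):
        # Register each implementation as e.g. self.mc2commimpl.
        for method in (AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl):
            setattr(self, method.__name__.lower(),
                    method(moe_config=moe_config))

layer = Layer(moe_config={})
assert isinstance(layer.alltoallcommimpl, AlltoAllCommImpl)
```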
@@ -332,6 +331,5 @@ def forward_impl(self, hidden_states: torch.Tensor,
         return final_hidden_states
 
 
-UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
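
These module-level assignments are the file's patching convention: replacement methods are swapped onto the upstream `UnquantizedFusedMoEMethod` at import time, and the earlier `original_unquantized_fused_moe_init_func` save follows the same wrap-and-call pattern. A generic sketch with illustrative names only:

```python
class Upstream:
    def forward(self, x):
        return x

# Keep a handle to the original so the patch can delegate to it.
_original_forward = Upstream.forward

def forward_patched(self, x):
    # Run the upstream logic first, then apply backend-specific behavior.
    return _original_forward(self, x) * 2

Upstream.forward = forward_patched
assert Upstream().forward(3) == 6
```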