@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
@@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
     else:
         pertoken_scale = dynamic_scale
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
149182 "tp_world_size" : tp_size ,
150183 "tp_rank_id" : tp_rank ,
151184 }
152- kwargs .update (stage1_kwargs )
185+ kwargs_mc2 .update (stage1_kwargs )
186+
187+ shared_experts = kwargs .get ('shared_experts' , None )
188+ if shared_experts :
189+ shared_hidden_states = kwargs .get ('shared_hidden_states' , None )
190+ with tng .scope .npu_stream_switch ('cv' ):
191+ tng .scope .npu_wait_tensor (shared_hidden_states , hidden_states )
192+ shared_x , shared_dynamic_scale = torch_npu .npu_dynamic_quant (
193+ shared_hidden_states )
194+ shared_gate_up = torch_npu .npu_quant_matmul (
195+ shared_x ,
196+ shared_experts .gate_up_proj .weight ,
197+ shared_experts .gate_up_proj .weight_scale ,
198+ output_dtype = torch .int32 ,
199+ )
200+ kwargs .update ({
201+ "shared_gate_up" : shared_gate_up ,
202+ "shared_dynamic_scale" : shared_dynamic_scale ,
203+ })
153204
154- output = torch_npu .npu_moe_distribute_dispatch (** kwargs )
205+ output = torch_npu .npu_moe_distribute_dispatch (** kwargs_mc2 )
155206 # comm_stream.wait_stream(torch.npu.current_stream())
156207 expand_x , dynamic_scale , expand_idx , expert_token_nums , ep_recv_counts = output [
157208 0 :5 ]
@@ -166,10 +217,15 @@ def fused_experts_with_mc2(
                               w2,
                               w2_scale,
                               expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+                              dynamic_scale=dynamic_scale,
+                              **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -193,10 +249,12 @@ def fused_experts_with_mc2(
193249 "tp_world_size" : tp_size ,
194250 "tp_rank_id" : tp_rank ,
195251 }
196- kwargs .update (stage3_kwargs )
252+ kwargs_mc2 .update (stage3_kwargs )
197253
198- hidden_states = torch_npu .npu_moe_distribute_combine (** kwargs )
254+ hidden_states = torch_npu .npu_moe_distribute_combine (** kwargs_mc2 )
199255
256+ if multi_stream :
257+ return hidden_states , shared_output
200258 return hidden_states
201259
202260
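Once `shared_experts` is passed through `**kwargs`, both `apply_mlp` and `fused_experts_with_mc2` return a `(routed_hidden_states, shared_output)` tuple instead of a bare tensor, which is what the `multi_stream` check above guards. A minimal sketch of that tensor-or-tuple contract; the helper `unpack_moe_output` is purely illustrative and not part of vllm-ascend:

```python
# Illustrative sketch only: `unpack_moe_output` is a made-up helper showing the
# tensor-or-tuple return contract that the shared-experts path introduces.
import torch


def unpack_moe_output(out):
    """Return (routed_out, shared_out); shared_out is None on the single-stream path."""
    if isinstance(out, tuple):
        routed_out, shared_out = out
        return routed_out, shared_out
    return out, None


# Stand-in tensors; real callers receive NPU tensors from fused_experts_with_mc2.
routed, shared = torch.randn(4, 8), torch.randn(4, 8)
assert unpack_moe_output((routed, shared))[1] is shared
assert unpack_moe_output(routed)[1] is None
```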
@@ -634,7 +692,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
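The overlap itself comes from torchair's scoped secondary stream: work wrapped in `tng.scope.npu_stream_switch('cv')` is issued on a side stream and runs concurrently with the default stream, while `tng.scope.npu_wait_tensor` keeps the two pipelines ordered. A minimal sketch of that pattern, assuming an Ascend environment with torch_npu and torchair available (`side_stream_matmul` is a hypothetical helper, not vllm-ascend API):

```python
# Minimal sketch of the dual-stream pattern used in this patch; assumes torch_npu
# and torchair on an Ascend device. `side_stream_matmul` is hypothetical.
import torch
import torch_npu  # noqa: F401  (registers the NPU backend)
import torchair as tng  # type: ignore


def side_stream_matmul(shared_in: torch.Tensor, routed_trigger: torch.Tensor,
                       weight: torch.Tensor) -> torch.Tensor:
    # Ops inside this scope are issued on the secondary 'cv' stream, so they can
    # overlap with routed-expert kernels running on the default stream.
    with tng.scope.npu_stream_switch('cv'):
        # Hold the side stream until `routed_trigger` has been produced before
        # reading `shared_in`, mirroring the npu_wait_tensor calls in apply_mlp above.
        tng.scope.npu_wait_tensor(shared_in, routed_trigger)
        return torch.matmul(shared_in, weight)
```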