
import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -47,6 +48,7 @@ def fused_experts_with_mc2(
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    global_bs = 0
    moe_expert_num = len(expert_map)
@@ -83,11 +85,20 @@ def fused_experts_with_mc2(
    }
    kwargs.update(stage1_kwargs)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
    w1 = w1.transpose(1, 2)
    expert_token_nums = torch.cumsum(expert_token_nums,
                                     dim=0,
@@ -118,6 +129,11 @@ def fused_experts_with_mc2(

    down_out_list = torch.cat(down_out_list, dim=0)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
    # moeCombine
    kwargs = {
        "expand_x": down_out_list,
@@ -145,7 +161,7 @@ def fused_experts_with_mc2(

    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)

-    return hidden_states
+    return hidden_states, shared_hidden_states if shared_experts is not None else None


# currently expert parallelism implemented with all2all
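
Note on the pattern added in the hunks above: the shared-expert projections are issued inside `npu_stream_switch("expert_secondary")` so they run on a secondary NPU stream alongside the MC2 dispatch and routed-expert kernels on the default stream, and `npu_wait_tensor` inserts the ordering points. The sketch below only restates that pattern in isolation; the two helpers come from `vllm_ascend.utils` as imported in this patch, their exact semantics are assumed here, and the module/tensor names are placeholders.

```python
# Minimal sketch of the multistream overlap used in fused_experts_with_mc2,
# assuming npu_stream_switch(name) runs the enclosed ops on a named secondary
# stream and npu_wait_tensor(a, b) delays further use of `a` until `b` exists.
import torch
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def shared_expert_overlapped(shared_experts: torch.nn.Module,
                             hidden_states: torch.Tensor,
                             dispatch_out: torch.Tensor,
                             routed_down_out: torch.Tensor) -> torch.Tensor:
    # Stage 1: start the shared-expert up-projection while the routed experts'
    # dispatch runs on the default stream.
    with npu_stream_switch("expert_secondary"):
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
    # Stage 2: gate the activation behind the dispatch output, mirroring
    # npu_wait_tensor(shared_gate_up, expand_x) in the patch.
    with npu_stream_switch("expert_secondary"):
        npu_wait_tensor(shared_gate_up, dispatch_out)
        shared_act = shared_experts.act_fn(shared_gate_up)
    # Stage 3: gate the down-projection behind the routed experts' MLP output,
    # mirroring npu_wait_tensor(shared_act, down_out_list).
    with npu_stream_switch("expert_secondary"):
        npu_wait_tensor(shared_act, routed_down_out)
        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
    return shared_hidden_states
```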
@@ -624,6 +640,7 @@ def apply(
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +681,35 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+            )
        elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
        else:
            # The current implementation of deepseek moe splits hidden_states
            # according to tp_size before they are fed into the fused_moe
            # module. Therefore, all2all is needed no matter how dp/tp is set
            # so as to dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(hidden_states=x,
+                                                              w1=layer.w13_weight,
+                                                              w2=layer.w2_weight,
+                                                              topk_weights=topk_weights,
+                                                              topk_ids=topk_ids,
+                                                              top_k=top_k,
+                                                              expert_map=expert_map,
+                                                              ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)


class AscendFusedMoE(FusedMoE):
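
The comment in the hunk above points out that hidden_states arrive already split by tp_size, so tokens still have to be exchanged across ranks before and after the experts run. fused_experts_with_all2all itself is not part of this diff; the toy snippet below is only a hypothetical illustration of that dispatch step using plain torch.distributed, and every name in it is invented for the example.

```python
import torch
import torch.distributed as dist


def toy_all2all_dispatch(tokens: torch.Tensor, dest_rank: torch.Tensor,
                         group=None) -> torch.Tensor:
    """Toy token dispatch: send each token to the rank that owns its expert.

    Purely illustrative; the real path fuses this with quantization,
    expert-map handling and the combine step.
    """
    world = dist.get_world_size(group)
    # Group local tokens by destination rank and count how many go to each.
    order = torch.argsort(dest_rank)
    tokens = tokens[order]
    send_counts = torch.bincount(dest_rank, minlength=world)
    recv_counts = torch.empty_like(send_counts)
    # Exchange per-rank counts first so the receive buffer can be sized.
    dist.all_to_all_single(recv_counts, send_counts, group=group)
    out = tokens.new_empty((int(recv_counts.sum()), tokens.shape[-1]))
    # Swap the tokens themselves with variable split sizes.
    dist.all_to_all_single(out,
                           tokens,
                           output_split_sizes=recv_counts.tolist(),
                           input_split_sizes=send_counts.tolist(),
                           group=group)
    return out
```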
@@ -815,7 +839,8 @@ def forward(self,
                router_logits: torch.Tensor,
                is_prefill: bool,
                enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
        assert self.quant_method is not None

        if top_k:
@@ -842,7 +867,9 @@ def forward(self,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )

        if VLLM_ENABLE_MC2 and not is_prefill:
            ...
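
With this change, fused_experts_with_mc2 always returns a (routed, shared-or-None) pair, while the EP=1 and all2all branches of apply still return a bare tensor when no shared_experts is passed. Below is a hypothetical caller-side helper, not part of the patch, that normalizes the two shapes of return value; the plain sum is only a stand-in for however the model actually combines routed and shared expert outputs.

```python
from typing import Optional, Tuple, Union

import torch

MoEOutput = Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]


def merge_moe_output(out: MoEOutput) -> torch.Tensor:
    # Accept either a bare routed tensor or a (routed, shared-or-None) pair
    # and fold the shared contribution in. The addition is a placeholder for
    # the model-specific combination of routed and shared outputs.
    if isinstance(out, tuple):
        routed, shared = out
        return routed if shared is None else routed + shared
    return out
```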