 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -47,6 +48,7 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
@@ -83,11 +85,20 @@ def fused_experts_with_mc2(
     }
     kwargs.update(stage1_kwargs)
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -118,6 +129,11 @@ def fused_experts_with_mc2(
 
     down_out_list = torch.cat(down_out_list, dim=0)
 
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
     # moeCombine
     kwargs = {
         "expand_x": down_out_list,
@@ -145,7 +161,7 @@ def fused_experts_with_mc2(
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
 
-    return hidden_states
+    return (hidden_states, shared_hidden_states) if shared_experts is not None else hidden_states
 
 
 # currently expert parallelism implemented with all2all
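# [Illustration, not part of the diff] The three `shared_experts` blocks added
# above overlap the shared expert's gate_up_proj / act_fn / down_proj with the
# MC2 dispatch, grouped expert GEMMs and combine by running them on a secondary
# stream, with npu_wait_tensor sequencing each shared-expert step behind the
# routed-path tensor it follows. Below is a minimal sketch of the same overlap
# pattern using plain PyTorch CUDA streams; `routed_fn` and the stream handling
# are illustrative assumptions, not the vllm-ascend helpers.

import torch

def overlapped_shared_expert(hidden_states, routed_fn, shared_experts):
    side = torch.cuda.Stream()  # secondary stream, analogous to "expert_secondary"

    # The side stream must observe hidden_states before launching its first matmul.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)

    # Routed path (stand-in for dispatch -> expert GEMMs -> combine) keeps the
    # default stream busy while the side stream computes the shared expert.
    router_out = routed_fn(hidden_states)

    with torch.cuda.stream(side):
        shared_act = shared_experts.act_fn(shared_gate_up)
        shared_out, _ = shared_experts.down_proj(shared_act)

    # Join: the default stream must not read shared_out before the side stream is done.
    torch.cuda.current_stream().wait_stream(side)
    return router_out, shared_out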
@@ -624,6 +640,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +681,36 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+            )
         elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -815,7 +840,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
         assert self.quant_method is not None
 
         if top_k:
@@ -842,7 +868,9 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )
 
         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
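# [Illustration, not part of the diff] With `shared_experts` plumbed through,
# `apply(...)` returns a plain routed tensor when no shared experts are passed
# and a `(router_hidden_states, shared_hidden_states)` pair when they are.
# Assuming `AscendFusedMoE.forward(...)` propagates that result unchanged and
# takes hidden_states as its first positional argument, a caller might look
# like the sketch below; `moe` and `routed_scaling_factor` are assumptions for
# illustration.

def moe_layer_forward(moe, shared_experts, hidden_states, router_logits,
                      is_prefill, routed_scaling_factor=1.0):
    out = moe(hidden_states,
              router_logits=router_logits,
              is_prefill=is_prefill,
              shared_experts=shared_experts)

    if shared_experts is not None:
        # Routed and shared outputs come back together; just combine them.
        router_hidden_states, shared_hidden_states = out
        return router_hidden_states * routed_scaling_factor + shared_hidden_states

    # Without shared experts the behaviour is unchanged: a single routed tensor.
    return out * routed_scaling_factor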