@@ -587,6 +587,12 @@ def __init__(self, moe: MoEConfig = None):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
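For reference, `additional_config` is vLLM's free-form pass-through dict for platform-specific options, so the flag read above would be supplied from the client side roughly as below. This is a hedged sketch: the model name is a placeholder, and the exact constructor plumbing may differ across vLLM versions.

```python
# Hedged usage sketch: supplying the "enable_graph_mode" flag that the
# diff reads out of additional_config. Model name is a placeholder.
from vllm import LLM

llm = LLM(
    model="some/MoE-model",  # placeholder, not from the diff
    additional_config={"enable_graph_mode": True},
)
```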
@@ -664,7 +670,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif get_ep_group().world_size == 1:
+        elif self.enable_graph_mode or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
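Graph mode now funnels into the same `fused_experts` path as the single-EP-rank case. As a reminder of what that call computes, here is a naive dense reference (not the NPU kernel; shapes are assumed from the `w13_weight`/`w2_weight` names, with `w13` packing the gate and up projections of a SwiGLU MLP):

```python
import torch
import torch.nn.functional as F

def fused_experts_reference(
        hidden_states: torch.Tensor,  # [T, H] tokens
        w13: torch.Tensor,            # [E, 2*I, H] packed gate+up proj (assumed layout)
        w2: torch.Tensor,             # [E, H, I] down proj (assumed layout)
        topk_weights: torch.Tensor,   # [T, K] routing weights
        topk_ids: torch.Tensor,       # [T, K] expert ids per token
) -> torch.Tensor:
    out = torch.zeros_like(hidden_states)
    for e in range(w13.shape[0]):
        token_idx, k_idx = torch.where(topk_ids == e)
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        x = hidden_states[token_idx]
        gate, up = (x @ w13[e].t()).chunk(2, dim=-1)  # SwiGLU halves
        y = (F.silu(gate) * up) @ w2[e].t()           # [n, H]
        # Accumulate each expert's output, scaled by its routing weight.
        out.index_add_(0, token_idx,
                       y * topk_weights[token_idx, k_idx].unsqueeze(-1))
    return out
```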
@@ -750,26 +756,20 @@ def __init__(
         self.expert_map = None
         self.activation = activation

-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-
-            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
-            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+        # Create a tensor of size num_experts filled with -1
+        self.local_num_experts, self.expert_map = determine_expert_map(
+            self.ep_size,
+            get_ep_group().rank_in_group, self.global_num_experts)

-        else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
+        self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+        self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

-            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-            self.moe_parallel_config.ep_rank = 0
-            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-            self.moe_parallel_config.ep_size = 1
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

-            self.local_num_experts, self.expert_map = (self.global_num_experts,
-                                                       None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
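The "tensor of size num_experts filled with -1" comment refers to the contract of `determine_expert_map`: a length-`global_num_experts` tensor holding -1 for experts owned by other EP ranks and a local index for the rest. A minimal sketch of that contract, assuming a contiguous equal-block partition (the upstream helper also handles remainders):

```python
import torch

def expert_map_sketch(ep_size: int, ep_rank: int, global_num_experts: int):
    """Sketch of the (local_num_experts, expert_map) contract, assuming
    experts are split into contiguous, equal blocks across EP ranks."""
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

# e.g. ep_size=4, ep_rank=1, 8 experts -> map [-1, -1, 0, 1, -1, -1, -1, -1]
```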
@@ -807,8 +807,15 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
@@ -822,11 +829,28 @@ def forward(self,
         else:
             real_top_k = self.top_k

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        #               MC2    ag/rs   broadcast/all_reduce
+        # prefill_req    x       x             √
+        # decode_req     √       x             √
+        # graph_mode     √       √             x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = get_dp_group().all_gather(
+                        hidden_states, 0, False)
+                    router_logits = get_dp_group().all_gather(
+                        router_logits, 0, False)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+                else:
+                    hidden_states, router_logits = get_ep_group().dispatch(
+                        hidden_states, router_logits)

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -843,11 +867,26 @@ def forward(self,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance)

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                else:
+                    hidden_states = get_ep_group().combine(hidden_states)

         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

-        return final_hidden_states
+        return hidden_states
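Taken together, the two new `dp_size > 1` blocks make `forward` a gather/compute/scatter round trip: activations are all-gathered (or EP-dispatched) across data-parallel ranks before `quant_method.apply`, each rank runs the full batch through its local experts only, and the partial results are reduce-scattered (or combined) afterwards, so every rank ends up holding just its own tokens again. A single-process sketch of that shape symmetry, emulating the collectives with plain tensor ops rather than the real `get_dp_group()` API:

```python
import torch

dp_size, tokens_per_rank, hidden = 2, 4, 8
# Each emulated DP rank starts with its own slice of the batch.
local = [torch.randn(tokens_per_rank, hidden) for _ in range(dp_size)]

# all_gather(dim=0): every rank now sees the full global batch.
gathered = torch.cat(local, dim=0)  # [dp_size * tokens_per_rank, hidden]

# Each rank runs the full batch through *its local experts only*,
# producing a partial result (emulated by an arbitrary per-rank transform).
partial = [gathered * (rank + 1) for rank in range(dp_size)]

# reduce_scatter("sum", scatter_dim=0): sum partials across ranks, then
# hand each rank back only its own token slice.
final = sum(partial).chunk(dp_size, dim=0)

assert final[0].shape == local[0].shape  # ranks end where they started
```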