@@ -587,6 +587,12 @@ def __init__(self, moe: MoEConfig = None):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
 
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
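For context, `additional_config` is the free-form dict that vLLM carries on its config object, so the `enable_graph_mode` flag read above has to be supplied when the engine is created. A minimal usage sketch, assuming the `additional_config` engine argument is available in the vLLM build this targets (the model name is only a placeholder):

```python
from vllm import LLM

# Hypothetical invocation: the dict below is what
# get_current_vllm_config().additional_config returns inside the hunks above.
llm = LLM(
    model="some/moe-model",  # placeholder model id
    additional_config={"enable_graph_mode": True},
)
```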
@@ -664,7 +670,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif get_ep_group().world_size == 1:
+        elif self.enable_graph_mode or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
@@ -750,26 +756,20 @@ def __init__(
         self.expert_map = None
         self.activation = activation
 
-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-
-            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
-            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+        # Create a tensor of size num_experts filled with -1
+        self.local_num_experts, self.expert_map = determine_expert_map(
+            self.ep_size,
+            get_ep_group().rank_in_group, self.global_num_experts)
 
-        else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
+        self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+        self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
-            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-            self.moe_parallel_config.ep_rank = 0
-            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-            self.moe_parallel_config.ep_size = 1
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
-        self.local_num_experts, self.expert_map = (self.global_num_experts,
-                                                   None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
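With the `if self.ep_size > 1` branch removed, `determine_expert_map` now runs unconditionally, and `expert_map` translates global expert ids into local ids, with `-1` marking experts owned by other EP ranks. A small illustrative sketch of that convention; `build_expert_map` is a hypothetical stand-in for the real helper, not its implementation:

```python
import torch

def build_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    """Toy expert map: global expert id -> local id, or -1 if the expert is remote."""
    # Split experts as evenly as possible across EP ranks.
    base = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    local_num = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)

    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    expert_map[start:start + local_num] = torch.arange(local_num, dtype=torch.int32)
    return local_num, expert_map

# e.g. 8 experts over 4 EP ranks: rank 1 owns global experts 2 and 3, so
# build_expert_map(4, 1, 8) == (2, tensor([-1, -1, 0, 1, -1, -1, -1, -1]))
```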
@@ -807,8 +807,15 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
@@ -822,11 +829,28 @@ def forward(self,
         else:
             real_top_k = self.top_k
 
-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        #               MC2    ag/rs   broadcast/all_reduce
+        # prefill_req    x       x              √
+        # decode_req     √       x              √
+        # graph_mode     √       √              x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = get_dp_group().all_gather(
+                        hidden_states, 0, False)
+                    router_logits = get_dp_group().all_gather(
+                        router_logits, 0, False)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+            else:
+                hidden_states, router_logits = get_ep_group().dispatch(
+                    hidden_states, router_logits)
 
         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -843,11 +867,26 @@ def forward(self,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance)
 
-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+            else:
+                hidden_states = get_ep_group().combine(hidden_states)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-        return final_hidden_states
+        return hidden_states
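The two graph-mode branches in `forward` form a matched pair: hidden states and router logits from every DP rank are all-gathered along dim 0 before `quant_method.apply`, and the partial expert outputs are reduce-scattered with a sum afterwards so each rank gets back only its own tokens. The diff uses the functional collective `dist._functional_collectives.reduce_scatter_tensor` so the op can be captured in graph mode; the sketch below shows the same round trip with eager `torch.distributed` collectives (the process-group setup and the shape-preserving `expert_fn` are assumptions, not taken from this diff):

```python
import torch
import torch.distributed as dist

def dp_moe_round_trip(local_tokens: torch.Tensor, expert_fn, dp_group) -> torch.Tensor:
    """all_gather -> per-rank expert compute -> reduce_scatter(sum), all along dim 0."""
    world_size = dist.get_world_size(dp_group)

    # 1. Gather every DP rank's tokens so each rank sees the global batch.
    gathered = [torch.empty_like(local_tokens) for _ in range(world_size)]
    dist.all_gather(gathered, local_tokens, group=dp_group)
    global_tokens = torch.cat(gathered, dim=0)

    # 2. Each rank applies only its local experts; tokens routed to remote
    #    experts contribute zeros, so the cross-rank sum completes them.
    partial_out = expert_fn(global_tokens)  # assumed to preserve shape

    # 3. Sum partial outputs across ranks and hand each rank back its slice.
    local_out = torch.empty_like(local_tokens)
    dist.reduce_scatter_tensor(local_out, partial_out,
                               op=dist.ReduceOp.SUM, group=dp_group)
    return local_out
```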