 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
 
         self.custom_routing_function = custom_routing_function
 
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
 
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
 
 
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -96,9 +109,6 @@ def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "
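
Note: the registration that exposes this real/fake pair as `torch.ops.vllm.transformers_moe_forward` lies outside the hunks shown above. A minimal sketch of how such a pair is typically registered, assuming vLLM's `direct_register_custom_op` helper (the exact import path and arguments in the actual file may differ):

from vllm.utils import direct_register_custom_op

# Assumed registration (not shown in this diff): wire the eager implementation
# to its shape-only fake so torch.compile can trace through the custom op.
direct_register_custom_op(
    op_name="transformers_moe_forward",
    op_func=transformers_moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=transformers_moe_forward_fake,
)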
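For reference, a standalone single-process sketch (not part of this change; names are illustrative) of what the expert-parallel branch in `custom_routing_function` does: each DP/EP rank holds only its own chunk of `topk_ids`, and `all_gatherv` concatenates the per-rank chunks along dim 0 in rank order, so every rank routes with the ids for the full data-parallel batch:

import torch

top_k = 2
sizes = [3, 1, 2]  # per-rank row counts, like get_chunk_sizes_across_dp_rank()
chunks = [torch.randint(0, 8, (n, top_k), dtype=torch.int32) for n in sizes]

rank = 1
assert chunks[rank].shape[0] == sizes[rank]  # mirrors the assert in the diff

# Single-process stand-in for `topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)`
gathered = torch.cat(chunks, dim=0)
assert gathered.shape == (sum(sizes), top_k)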