diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index c705a70b93f5..fdfcf89d9ab3 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
     - Data parallel
-    - Pipeline parallel
     - Tensor parallel
+    - Expert parallel
+    - Pipeline parallel
 
 Checking if the modeling backend is Transformers is as simple as:
 
diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py
index cb966256b350..f4ae015fdc64 100644
--- a/vllm/model_executor/models/transformers_moe.py
+++ b/vllm/model_executor/models/transformers_moe.py
@@ -22,6 +22,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
 
         self.custom_routing_function = custom_routing_function
 
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
 
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
 
 
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -96,9 +109,6 @@ def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "
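Below is a minimal, hypothetical sketch of how the expert-parallel support added above might be exercised through the Transformers backend. It is not part of this diff: the model name, parallel size, and prompt are placeholders, and it assumes the offline `LLM` entrypoint forwards the `model_impl` and `enable_expert_parallel` engine arguments as usual.

```python
from vllm import LLM, SamplingParams

# Placeholder MoE checkpoint; any MoE model loaded through the
# Transformers modeling backend is the intended target.
llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",  # placeholder model name
    model_impl="transformers",       # force the Transformers modeling backend
    tensor_parallel_size=2,          # placeholder parallel size
    enable_expert_parallel=True,     # previously raised NotImplementedError
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)
```

As read from the diff, the gather inside `custom_routing_function` covers the case where `hidden_states` has already been gathered across the data-parallel group (or the expert-parallel group under sequence parallelism) while the `topk_ids` stashed by the custom op are still rank-local, so they are all-gathered to the same row count before dispatch.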