3 changes: 2 additions & 1 deletion docs/models/supported_models.md
@@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
     - Data parallel
-    - Pipeline parallel
     - Tensor parallel
+    - Expert parallel
+    - Pipeline parallel
 
 Checking if the modeling backend is Transformers is as simple as:
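(The snippet that completes this sentence sits in the unchanged context below this hunk. For reference, a minimal sketch of such a check, assuming vLLM's `LLM.apply_model` API; the model name is an arbitrary example:)

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct")  # arbitrary example model

# A Transformers-backed model loads as a `Transformers*` class such as
# `TransformersForCausalLM`; a native vLLM model prints its own class.
llm.apply_model(lambda model: print(type(model)))
```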

50 changes: 30 additions & 20 deletions vllm/model_executor/models/transformers_moe.py
@@ -22,6 +22,7 @@

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
 
         self.custom_routing_function = custom_routing_function
 
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
 
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
 
 
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
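The substantive change in this hunk is the new branch in `custom_routing_function`: under data or expert parallelism, `hidden_states` arrives already gathered across ranks, while the router on this rank only produced `topk_ids` for its own chunk, so the ids are gathered with matching variable chunk sizes. A minimal single-process sketch of that shape mismatch (sizes and values are invented; the real code uses `dist_group.all_gatherv`):

```python
import torch

sizes = [3, 5]  # assumed token counts on two data-parallel ranks
top_k = 2

# After the DP all-gather, hidden_states already holds all 8 tokens...
hidden_states = torch.randn(sum(sizes), 16)
# ...but rank 0's router only scored its own 3 tokens.
local_topk_ids = torch.randint(0, 8, (sizes[0], top_k), dtype=torch.int32)

# This is the mismatch that triggers the gather branch above.
assert local_topk_ids.size(0) != hidden_states.size(0)

# all_gatherv concatenates each rank's chunk along dim 0; emulated here
# with the chunk rank 1 would contribute.
rank1_topk_ids = torch.randint(0, 8, (sizes[1], top_k), dtype=torch.int32)
topk_ids = torch.cat([local_topk_ids, rank1_topk_ids], dim=0)
assert topk_ids.size(0) == hidden_states.size(0)
```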

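Past the end of this hunk, the unchanged remainder of the file registers the real/fake pair as a torch custom op, which is why `forward` can call `torch.ops.vllm.transformers_moe_forward`. A sketch of what that registration typically looks like in vLLM; the exact arguments here are an assumption, not part of this diff:

```python
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

# Register the eager implementation together with a fake (meta) impl so
# that torch.compile can trace the op's output shape without running it.
direct_register_custom_op(
    op_name="transformers_moe_forward",
    op_func=transformers_moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=transformers_moe_forward_fake,
    dispatch_key=current_platform.dispatch_key,
)
```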
@@ -96,9 +109,6 @@ def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "
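With the guard removed, expert parallelism can now be combined with the Transformers backend. A hedged usage sketch: the checkpoint is an arbitrary example, while `model_impl`, `enable_expert_parallel`, and `tensor_parallel_size` are existing vLLM engine arguments:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen1.5-MoE-A2.7B",  # any MoE checkpoint, chosen for illustration
    model_impl="transformers",        # force the Transformers modeling backend
    enable_expert_parallel=True,      # newly allowed by this change
    tensor_parallel_size=2,
)
print(llm.generate("Hello")[0].outputs[0].text)
```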