diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 18c8b7846cb2..5ca2156c08b5 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". +vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". Currently, the Transformers backend works for the following: - Modalities: embedding models, language models and vision-language models* -- Architectures: encoder-only, decoder-only +- Architectures: encoder-only, decoder-only, mixture-of-experts - Attention types: full attention and/or sliding attention _*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._ @@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature) - Any combination of the following vLLM parallelisation schemes: + - Data parallel - Pipeline parallel - Tensor parallel diff --git a/tests/models/registry.py b/tests/models/registry.py index 1068f97cb5a8..86a835975227 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -661,6 +661,10 @@ def check_available_online( "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501 "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 } _EXAMPLE_MODELS = { diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 733ac8de67a3..bd443575127f 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -66,6 +66,7 @@ def check_implementation( [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE + ("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE ]) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], @@ -74,6 +75,14 @@ def test_models( model: str, model_impl: str, ) -> None: + import transformers + from packaging.version import Version + installed = Version(transformers.__version__) + required = Version("4.57.0.dev0") + if model == "allenai/OLMoE-1B-7B-0924" and installed < required: + pytest.skip("MoE models with 
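For readers who want to exercise the new MoE path end to end, here is a minimal, hypothetical usage sketch (not part of this diff). It assumes a CUDA GPU with enough memory for OLMoE-1B-7B, `transformers>=4.57.0.dev0`, and uses `model_impl="transformers"` to force the Transformers backend as the existing backend docs describe:

```python
# Hypothetical smoke test of the Transformers backend MoE path (not from this PR).
from vllm import LLM, SamplingParams

llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",
    model_impl="transformers",  # force the Transformers backend instead of a dedicated vLLM impl
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```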
the Transformers backend require " + f"transformers>={required}, but got {installed}") + check_implementation(hf_runner, vllm_runner, example_prompts, diff --git a/tests/models/utils.py b/tests/models/utils.py index 7e731cffc047..50936114865a 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -430,17 +430,26 @@ def dummy_hf_overrides( update_dict = { "num_layers": num_layers, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, # For Gemma-3n "num_kv_shared_layers": 1, } + class DummyConfig: + hf_text_config = text_config + + # Only set MoE related config when the model has MoE layers. + # Otherwise all models detected as MoE by _get_transformers_backend_cls. + if ModelConfig.get_num_experts(DummyConfig) > 0: + update_dict.update({ + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + }) + # Update num_hidden_layers for non-Longcat architectures if model_arch != "LongcatFlashForCausalLM" \ and model_arch != "LongCatFlashMTPModel": diff --git a/vllm/config/model.py b/vllm/config/model.py index e9d5b58ff2c2..2bf6a1671188 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -20,7 +20,7 @@ MultiModalConfig) from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import assert_hashable, config +from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -667,6 +667,8 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" + prefix = "Transformers" + prefix += "MoE" if self.get_num_experts() > 1 else "" # Check if the architecture we're wrapping has defaults runner = None convert = None @@ -685,15 +687,15 @@ def _get_transformers_backend_cls(self) -> str: # Resolve Transformers backend pooling classes if runner == "pooling": if convert == "embed": - return "TransformersEmbeddingModel" + return prefix + "EmbeddingModel" if convert == "classify": - return "TransformersForSequenceClassification" + return prefix + "ForSequenceClassification" # Resolve Transformers backend generate classes if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. 
multimodal - return "TransformersForMultimodalLM" - return "TransformersForCausalLM" + return prefix + "ForMultimodalLM" + return prefix + "ForCausalLM" def using_transformers_backend(self) -> bool: """Check if the model is using the Transformers backend class.""" @@ -1025,17 +1027,7 @@ def _verify_bnb_config(self) -> None: self.enforce_eager = True def _verify_with_expert_parallelism(self) -> None: - num_expert_names = [ - "moe_num_experts", # Dbrx - "num_experts", # Jamba - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = 0 - for name in num_expert_names: - num_experts = getattr(self.hf_text_config, name, 0) - if num_experts > 0: - break + num_experts = self.get_num_experts() if num_experts < 1: raise ValueError( "Number of experts in the model must be greater than 0 " @@ -1220,6 +1212,21 @@ def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size + def get_num_experts(self) -> int: + """Returns the number of experts in the model.""" + num_expert_names = [ + "num_experts", # Jamba + "moe_num_experts", # Dbrx + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + if isinstance(num_experts, list): + # Ernie VL's remote code uses list[int]... + # The values are always the same so we just take the first one. + return num_experts[0] + return num_experts + def get_layers_start_end_indices( self, parallel_config: ParallelConfig) -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9a7ca7b6d124..3b5ef78b37b0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -960,6 +960,7 @@ def __init__( is_sequence_parallel=False, zero_expert_num: Optional[int] = 0, zero_expert_type: Optional[str] = None, + expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, ): super().__init__() if params_dtype is None: @@ -996,6 +997,9 @@ def __init__( self.zero_expert_num = zero_expert_num self.zero_expert_type = zero_expert_type + # Expert mapping used in self.load_weights + self.expert_mapping = expert_mapping + # Round up hidden size if needed. 
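To make the class-selection change concrete, here is a toy stand-in (illustration only, not the real method) that mirrors how `_get_transformers_backend_cls` now composes the class name from the MoE prefix returned by `get_num_experts()`, the `runner`/`convert` defaults, and whether the config is composite (multimodal):

```python
# Toy mirror of the _get_transformers_backend_cls naming scheme (illustration only).
def backend_cls(num_experts: int, runner: str = "generate",
                convert: str = "none", is_multimodal: bool = False) -> str:
    prefix = "Transformers" + ("MoE" if num_experts > 1 else "")
    if runner == "pooling":
        if convert == "embed":
            return prefix + "EmbeddingModel"
        if convert == "classify":
            return prefix + "ForSequenceClassification"
    return prefix + ("ForMultimodalLM" if is_multimodal else "ForCausalLM")

assert backend_cls(64) == "TransformersMoEForCausalLM"                   # e.g. OLMoE
assert backend_cls(128, is_multimodal=True) == "TransformersMoEForMultimodalLM"
assert backend_cls(0, runner="pooling", convert="embed") == "TransformersEmbeddingModel"
```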
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype, quant_config, @@ -1617,6 +1621,33 @@ def weight_loader(self, return False if return_success else None + def load_weights( + self, weights: Iterable[tuple[str, + torch.Tensor]]) -> Iterable[str]: + if (expert_mapping := self.expert_mapping) is None: + raise ValueError("`self.expert_mapping` must be provided to " + "load weights using `self.load_weights`.") + for expert_name, loaded_weight in weights: + qual_name = f"{self.layer_name}.{expert_name}" + for param_name, weight_name, expert_id, shard_id in expert_mapping: + if weight_name not in qual_name: + continue + weight_name = qual_name.replace(weight_name, param_name) + param_name = weight_name.removeprefix(f"{self.layer_name}.") + param = getattr(self, param_name) + success = self.weight_loader( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + logger.debug("Loaded %s for expert %d into %s", param_name, + expert_id, self.layer_name) + yield param_name + def get_expert_weights(self) -> Iterable[torch.Tensor]: weights = list(self.named_parameters()) assert all(weight.is_contiguous() for _, weight in weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index eb572dc30810..94744fe558bd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -307,10 +307,14 @@ } _TRANSFORMERS_BACKEND_MODELS = { - "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501 - "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501 "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 + "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501 + "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501 + "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501 + "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501 + "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501 } # yapf: enable diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 070c77073bb0..18a0dafd001d 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -22,6 +22,8 @@ import regex as re import torch +import transformers +from packaging.version import Version from torch import nn from transformers import (AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel) @@ -35,6 +37,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -121,10 +124,14 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: return enable +Style = Literal["colwise", "colwise_rep", 
"rowwise", "rowwise_rep", + "replicate"] + + def replace_linear_class( linear: nn.Linear, - style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig, + style: Style = "replicate", + quant_config: Optional[QuantizationConfig] = None, *, prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: @@ -132,11 +139,11 @@ def replace_linear_class( Replace nn.Linear with one of vLLM's tensor parallel linear classes. Args: - linear (nn.Linear): `nn.Linear` to be replaced. - style (str): Tensor parallel style of the new linear, e.g. "colwise". - quant_config (QuantConfig): Quantization config for the new linear. + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. Returns: - Union[ColumnParallelLinear, RowParallelLinear]: The new linear. + The new linear. """ if not isinstance(style, str): @@ -166,6 +173,31 @@ def replace_linear_class( ) +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: + """Replace a Transformers RMSNorm with vLLM's RMSNorm. + + This method assumes: + - Weight is stored as `weight`. + - Epsilon is stored as `eps` or `variance_epsilon`. + - `with_scale` indicates whether the layer has a weight (Gemma3n only). + - `var_hidden_size` is only ever used for Intern vision encoder in vLLM + and Transformers doesn't appear to have the same concept. + """ + kwargs = { + "hidden_size": hidden_size, + "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), + "has_weight": getattr(rms_norm, "with_scale", True) + } + if (weight := getattr(rms_norm, "weight", None)) is not None: + # If weight is a Parameter, get its data tensor + weight = getattr(weight, "data", weight) + kwargs["dtype"] = weight.dtype + else: + # No weight, fall back to weightless RMSNorm + kwargs["has_weight"] = False + return RMSNorm(**kwargs) + + # Copied from `accelerate` @contextmanager def init_on_device_without_buffers(device: torch.device): @@ -463,9 +495,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.ignore_unexpected_suffixes: list[str] = [] """Ignore unexpected weights whose qualname ends with these suffixes.""" - # Skip loading extra bias for GPTQ models. - if self.quant_config and "gptq" in self.quant_config.get_name(): - self.ignore_unexpected_suffixes.append(".bias") + if self.quant_config: + quant_method_name = self.quant_config.get_name() + # Check for unsupported quantization methods. + if quant_method_name == "mxfp4": + raise NotImplementedError("Transformers backend does not " + "support MXFP4 quantization yet.") + # Skip loading extra bias for GPTQ models. 
+ if "gptq" in quant_method_name: + self.ignore_unexpected_suffixes.append(".bias") # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` @@ -478,8 +516,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): trust_remote_code=self.model_config.trust_remote_code, ) + # Remove layers not on this pipeline parallel rank self.pipeline_parallel() - self.tensor_parallel() + # Substitute remaining layers with vLLM's layers as needed + self.recursive_replace() + # Create attention instances for KV cache allocation + self.attention_instances = self.create_attention_instances() # Input embeddings if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): @@ -494,12 +536,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=self.quant_config, )) - # Attention layers - self.attention_instances = self.create_attention_instances() - # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) + # Pipeline parallel intermediate tensors self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states"], self.text_config.hidden_size)) @@ -558,56 +598,53 @@ def pipeline_parallel(self): if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) - def tensor_parallel(self): - """ - Apply the model's tensor parallelization plan. - Currently only supports linear layers. + def recursive_replace(self): + """Recursively replace modules in the model as needed. + + Currently, this replaces: + + - `nn.Linear` with vLLM's tensor parallel linear classes + - `*RMSNorm` with vLLM's `RMSNorm` """ - # Look for tp plans in all of the PreTrainedModels found in self.model - is_pretrained_model = lambda m: isinstance(m, PreTrainedModel) - supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None - pretrained_models = filter(is_pretrained_model, self.model.modules()) - models_with_tp_plan = filter(supports_tp_plan, pretrained_models) + tp_plan = self.model.tp_plan - if not any(models_with_tp_plan) and self.tp_size > 1: + if not tp_plan and self.tp_size > 1: tip = get_feature_request_tip(self.model_config.model, self.model_config.trust_remote_code) raise ValueError( f"{type(self.model)} does not support tensor parallel. 
{tip}") - def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None): - tp_plan = tp_plan or {} - - # If the current module is a PreTrainedModel, set the tp_plan for - # all of its children - if isinstance(module, PreTrainedModel): - tp_plan = module.config.base_model_tp_plan or {} - tp_plan = { - maybe_prefix(prefix, k): v - for k, v in tp_plan.items() - } - - # Some weight loaders expect linear layers to inherit from vLLM's - # LinearBase class, so we set a default style which causes any - # unspecified linear layers to be replaced with ReplicatedLinear + # Prefix the patterns because we always start from `self.model` + tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} + + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): + new_module = child_module qual_name = maybe_prefix(prefix, child_name) if isinstance(child_module, nn.Linear): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) + # Some weight loaders expect all linear layers to inherit + # LinearBase, so we set a default style which causes any + # unspecified layers to be replaced with ReplicatedLinear style = tp_plan.get(pattern, "replicate") new_module = replace_linear_class(child_module, style, self.quant_config, prefix=qual_name) + # TODO(hmellor): Enable RMSNorm replacement once we have a way + # to choose RMSNorm vs GemmaRMSNorm + # elif child_module.__class__.__name__.endswith("RMSNorm"): + # new_module = replace_rms_norm_class( + # child_module, self.config.hidden_size) + else: + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not child_module: setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) - else: - _tensor_parallel(child_module, - prefix=qual_name, - tp_plan=tp_plan) - _tensor_parallel(self.model, prefix="model") + _recursive_replace(self.model, prefix="model") def create_attention_instances( self, @@ -657,15 +694,21 @@ def init_parameters(self, self.model: PreTrainedModel = AutoModel.from_config(...) 
``` """ - for name, param in module.named_parameters(recurse=False): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like(param.data, - dtype=dtype or self.model_config.dtype, - device=self.device_config.device)) - setattr(module, name, new_param) - for child in module.children(): - self.init_parameters(child, dtype) + + def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype or self.model_config.dtype, + device=self.device_config.device, + )) + setattr(module, name, new_param) + for child in module.children(): + _init_parameters(child, dtype) + + _init_parameters(module, dtype) def forward( self, @@ -702,8 +745,10 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=self.skip_prefixes, @@ -713,6 +758,14 @@ def load_weights(self, weights: Iterable[tuple[str, ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def check_version(self, min_version: str, feature: str): + installed = Version(transformers.__version__) + required = Version(min_version) + if installed < required: + raise ImportError( + f"Transformers backend requires transformers>={required} " + f"for {feature}, but got {installed}") + @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py new file mode 100644 index 000000000000..cb966256b350 --- /dev/null +++ b/vllm/model_executor/models/transformers_moe.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
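A side note on the version gates (`check_version` above and the `min_transformers_version` pins in the test registry): they compare with `packaging.version.Version` rather than plain strings because dev pre-releases sort below the corresponding final release, which naive string comparison gets backwards. A quick sketch:

```python
# Why the gates use packaging.version.Version instead of string comparison.
from packaging.version import Version

assert "4.57.0.dev0" > "4.57.0"                        # lexicographic: wrong answer
assert Version("4.57.0.dev0") < Version("4.57.0")      # dev pre-release sorts first
assert Version("4.56.2") < Version("4.57.0.dev0")      # too-old installs are rejected
assert not (Version("4.57.0") < Version("4.57.0.dev0"))  # the final release passes the gate
```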
+"""Wrapper around `transformers` MoE models.""" +from typing import Any + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config.utils import getattr_iter +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + +from .transformers import (TransformersBase, TransformersForCausalLM, + TransformersForMultimodalLM, + can_enable_torch_compile, log_replacement) +from .utils import maybe_prefix + + +@CustomOp.register("transformers_fused_moe") +class TransformersFusedMoE(FusedMoE): + """Custom FusedMoE for the Transformers backend.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._top_k_index: torch.Tensor = None + + def custom_routing_function(hidden_states, gating_output, topk, + renormalize): + """Return `top_k_weights` from `gating_output` and the + `top_k_index` we stored in the layer earlier.""" + return gating_output, self._top_k_index + + self.custom_routing_function = custom_routing_function + + def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """In Transformers `experts.forward` will have this signature. + + We discard any extra kwargs because we cannot use them here.""" + return torch.ops.vllm.transformers_moe_forward(hidden_states, + top_k_index, + top_k_weights, + self.layer_name) + + +def transformers_moe_forward(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + """Store the `top_k_index` in the layer and call the actual forward.""" + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._top_k_index = top_k_index + # Clone hidden_states because it will be mutated in-place in FusedMoE + return self.forward_impl(hidden_states.clone(), top_k_weights) + + +def transformers_moe_forward_fake(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="transformers_moe_forward", + op_func=transformers_moe_forward, + mutates_args=["hidden_states"], + fake_impl=transformers_moe_forward_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +class TransformersMoEBase(TransformersBase): + + def __init__(self, *, vllm_config, prefix=""): + self.check_version("4.57.0.dev0", "MoE models support") + super().__init__(vllm_config=vllm_config, prefix=prefix) + + if self.parallel_config.enable_expert_parallel: + raise NotImplementedError( + "Transformers backend does not support expert parallel yet.") + if self.parallel_config.enable_eplb: + raise NotImplementedError( + "Transformers backend does not support expert parallel load " + "balancing yet.") + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """ + Params for weights, fp8 weight scales, fp8 activation scales + (param_name, weight_name, expert_id, shard_id) + """ + ckpt_names = [ + # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) + ("gate_proj", "down_proj", "up_proj"), # Most common MoE style + ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style + ("linear", 
"linear_1", "linear_v"), # Grok1 style + ] + expert_mapping = [] + for gate_proj, down_proj, up_proj in ckpt_names: + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name=gate_proj, + ckpt_down_proj_name=down_proj, + ckpt_up_proj_name=up_proj, + num_experts=self.model_config.get_num_experts(), + num_redundant_experts=0, # TODO: enable EPLB + )) + return expert_mapping + + def recursive_replace(self): + """Initialize the MoE layers.""" + text_config = self.text_config + + # Positional arguments + num_experts = self.model_config.get_num_experts() + top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"], + None) + assert top_k is not None + hidden_size = text_config.hidden_size + intermediate_size = getattr_iter( + text_config, ["moe_intermediate_size", "intermediate_size"], None) + assert intermediate_size is not None + + # If there are shared experts, the results are + # reduced after mlp.forward() not inside FusedMoE + num_experts_shared = getattr_iter(text_config, [ + "num_experts_shared", "n_shared_experts", "moe_num_shared_experts" + ], 0) + reduce_results = num_experts_shared == 0 + + def add_all_reduce(mlp: nn.Module): + """Adds an all-reduce to the output of `mlp.forward()`.""" + + class MLPWithAllReduce(mlp.__class__): + + def forward(self, *args, **kwargs): + output = super().forward(*args, **kwargs) + return self.experts.maybe_all_reduce_tensor_model_parallel( + output) + + mlp.__class__ = MLPWithAllReduce + + # Unused kwargs since we use custom_routing_function: + # - `scoring_func` and `e_score_correction_bias` only used for grouped + # topk routing inside vLLM and are non-trivial to infer + # and hard code `use_grouped_topk=False` + # - `renormalize` passed anyway because it's easy to infer + # - `num_expert_group` and `topk_group` used for inferring expert + # placement strategy in FusedMoE + # - `apply_router_weight_on_input` is already applied in Transformers + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) + num_expert_group = getattr(text_config, "n_group", None) + topk_group = getattr(text_config, "topk_group", None) + + # MoE activation function + activation = "silu" + wrapped_arch = self.config.architectures[0].lower() + if "gptoss" in wrapped_arch: + activation = "swigluoai" + elif "grok1" in wrapped_arch: + activation = "gelu" + + # Expert mapping for `AutoWeightsLoader` + expert_mapping = self.get_expert_mapping() + + # Configs + parallel_config = self.parallel_config + eplb_config = parallel_config.eplb_config + + # Expert parallel load balancing kwargs + enable_eplb = parallel_config.enable_eplb + num_redundant_experts = eplb_config.num_redundant_experts + + # Recursively fuse MoE layers + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + if (child_name == "experts" + and isinstance(child_module, nn.ModuleList)): + # Alias for readability + mlp = module + experts = child_module + # Do the experts have biases + has_bias = False + for experts_param_name, _ in experts.named_parameters(): + if "bias" in experts_param_name: + has_bias = True + break + # Double check there are no shared experts + nonlocal reduce_results + if reduce_results: + for mlp_param_name, _ in mlp.named_parameters(): + if "shared_expert" in mlp_param_name: + reduce_results = False + break + # Replace experts module with FusedMoE + fused_experts = TransformersFusedMoE( + num_experts=num_experts, + top_k=top_k, + 
hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=reduce_results, + renormalize=renormalize, + # Hard coded because topk happens in Transformers + use_grouped_topk=False, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=self.quant_config, + prefix=qual_name, + activation=activation, + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, + has_bias=has_bias, + expert_mapping=expert_mapping, + ) + mlp.experts = fused_experts + log_replacement(qual_name, experts, fused_experts) + # If results are not all-reduced in FusedMoE, ensure they + # are all-reduced at the end of mlp.forward() if tensor + # parallel or expert parallel is enabled + if not reduce_results and (fused_experts.tp_size > 1 + or fused_experts.ep_size > 1): + add_all_reduce(mlp) + else: + _recursive_replace(child_module, prefix=qual_name) + + _recursive_replace(self.model, prefix="model") + # Continue with the replacement of layers in TransformersBase + super().recursive_replace() + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): + pass + + +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }, + enable_if=can_enable_torch_compile) +class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, + TransformersForMultimodalLM): + pass diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py index 7e262ade156a..27fd40999fe2 100644 --- a/vllm/model_executor/models/transformers_pooling.py +++ b/vllm/model_executor/models/transformers_pooling.py @@ -20,7 +20,7 @@ import torch from transformers import AutoModelForSequenceClassification -from vllm.attention import AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, @@ -29,6 +29,7 @@ from .interfaces_base import VllmModelForPooling from .transformers import TransformersBase, can_enable_torch_compile +from .transformers_moe import TransformersMoEBase from .utils import WeightsMapper @@ -79,7 +80,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = self.text_config.pad_token_id def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER): + self, + attn_type: AttentionType = AttentionType.DECODER + ) -> dict[int, Attention]: # TODO(hmellor): Better way to detect encoder models # In encoder models, the attention layers will have `is_causal=False` is_encoder = lambda m: not getattr(m, "is_causal", True) @@ -90,14 +93,7 @@ def create_attention_instances( # Check minimum transformers version for encoder models support if attn_type == AttentionType.ENCODER_ONLY: - import transformers - from packaging.version import Version - installed = Version(transformers.__version__) - required = Version("4.57.0.dev0") - if installed < required: - raise ValueError( - "Encoder models with the Transformers backend require " - f"transformers>={required}, but got {installed}") + self.check_version("4.57.0.dev0", "encoder models support") return super().create_attention_instances(attn_type) @@ -198,3 +194,15 @@ def forward(self, *args, **kwargs): vllm_config.model_config), ), }) + + 
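To make the routing hand-off in `transformers_moe.py` above easier to follow: Transformers computes top-k routing itself, so `TransformersFusedMoE` stores the precomputed `top_k_index` and hands it back through `custom_routing_function`, while the router weights are passed straight through as `gating_output`. A CPU-only toy of that contract (illustration only; the real function is the closure defined in `TransformersFusedMoE.__init__`):

```python
# Toy of the custom_routing_function contract used by TransformersFusedMoE:
# the "routing" just echoes indices already chosen by the HF router.
import torch

def make_echo_routing(stored_top_k_index: torch.Tensor):
    def custom_routing_function(hidden_states, gating_output, topk, renormalize):
        # gating_output is already the per-token top-k weights from Transformers
        return gating_output, stored_top_k_index
    return custom_routing_function

top_k_weights = torch.rand(4, 2)           # produced by the HF router
top_k_index = torch.randint(0, 8, (4, 2))  # produced by the HF router
routing = make_echo_routing(top_k_index)
weights, index = routing(torch.randn(4, 16), top_k_weights, topk=2, renormalize=True)
assert torch.equal(index, top_k_index) and torch.equal(weights, top_k_weights)
```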
+@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEEmbeddingModel(TransformersMoEBase, + TransformersEmbeddingModel): + pass + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForSequenceClassification( + TransformersMoEBase, TransformersForSequenceClassification): + pass
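Finally, a hypothetical end-to-end sketch for the new MoE pooling classes (not part of this diff; the `runner`/`convert` argument names follow the `ModelConfig` fields referenced earlier in this diff and may be spelled differently depending on the vLLM version):

```python
# Hypothetical: serve an MoE checkpoint as an embedding model through the
# Transformers backend, which should resolve to TransformersMoEEmbeddingModel.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # example MoE checkpoint from the test registry
    runner="pooling",             # assumption: mirrors ModelConfig.runner
    convert="embed",              # assumption: mirrors ModelConfig.convert
    model_impl="transformers",
)
embedding = llm.embed(["vLLM Transformers backend"])[0].outputs.embedding
print(len(embedding))
```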