diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py
index 42081a8c68cd..6c541fdbeeae 100644
--- a/tests/quantization/test_register_quantization_config.py
+++ b/tests/quantization/test_register_quantization_config.py
@@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):
 
     def __init__(self, num_bits: int = 8) -> None:
         """Initialize the quantization config."""
+        super().__init__()
         self.num_bits = num_bits
 
     def get_name(self) -> QuantizationMethods:
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 262e6799583a..9e1ed3a77179 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -805,7 +805,7 @@ def create_lora_manager(
         lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not hasattr(model, "packed_modules_mapping"):
+    if not isinstance(model, SupportsLoRA):
         raise ValueError(f"Model {type(model)} is not supported for LoRA.")
     lora_manager = lora_manager_cls(
         model=model,
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 7da44569f408..7a4af74cbeb1 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -111,10 +111,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
             # to ensure correct loading of lora weights.
             model = self._adapter_manager.model
-            hf_to_vllm_mapper = None
-            if (hasattr(model, "hf_to_vllm_mapper")
-                    and model.hf_to_vllm_mapper is not None):
-                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
 
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 78c5c75c0651..4a43351260e9 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -10,6 +10,7 @@
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 else:
     QuantizationMethods = str
 
@@ -149,3 +150,15 @@ def get_quant_method(self, layer: torch.nn.Module,
 
     def get_cache_scale(self, name: str) -> Optional[str]:
         return None
+
+    def apply_vllm_mapper(  # noqa: B027
+            self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        # TODO (@kylesayrs): add implementations for all subclasses
+        pass
diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py
index 9e5ce39ec8f2..aa8eee88a9f9 100644
--- a/vllm/model_executor/layers/quantization/bitblas.py
+++ b/vllm/model_executor/layers/quantization/bitblas.py
@@ -63,6 +63,7 @@ def __init__(
             # (since we have only one group per output channel)
             desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 4f87b2a44f0a..e7f65d13181d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import suppress
-from typing import Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 import torch
 from compressed_tensors.config import (CompressionFormat,
@@ -37,6 +37,9 @@
     cutlass_fp4_supported)
 from vllm.platforms import current_platform
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 __all__ = ["CompressedTensorsLinearMethod"]
@@ -80,6 +83,18 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.target_scheme_map)
+        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
+        self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.sparsity_scheme_map)
+        self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
+            self.sparsity_ignore_list)
+        if self.kv_cache_scheme is not None:
+            self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
+                self.kv_cache_scheme)
+
     def get_quant_method(
         self,
         layer: torch.nn.Module,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 93472207fbb8..60df679a74bd 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -39,6 +39,9 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
@@ -100,6 +103,11 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.ignored_layers is not None:
+            self.ignored_layers = hf_to_vllm_mapper.apply_list(
+                self.ignored_layers)
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py
index 78e0f59fa4be..caeb266d0b93 100644
--- a/vllm/model_executor/layers/quantization/gptq_bitblas.py
+++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py
@@ -81,6 +81,7 @@ def __init__(
             # (since we have only one group per output channel)
             desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 62667db26b66..18d1c13373df 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -32,6 +32,8 @@ def __init__(
         group_size: int,
         lm_head_quantized: bool,
     ) -> None:
+        super().__init__()
+
         # Group size for the quantization.
         self.group_size = group_size
         self.lm_head_quantized = lm_head_quantized
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e35db5b31dba..a10911b84afc 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -181,6 +181,7 @@ def __init__(
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index a4e0356c0268..63b2ab6bab06 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -55,6 +55,7 @@ def __init__(self,
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        super().__init__()
        self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
 
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 79e6fa7b16dc..159e7b1e6b01 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -24,6 +24,7 @@
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
 
     Note that model attributes are passed by reference to quant_config,
     enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+
+    Once the `SupportsQuant` mixin has been added to all models, this
+    function can be removed
     """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+    if not issubclass(model_class, SupportsQuant):
+        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+
+        # pass mappings by reference to quant_config
+        if hf_to_vllm_mapper is not None:
+            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if packed_mapping is not None:
+            quant_config.packed_modules_mapping = packed_mapping
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index ad59fe79edcb..d234632ef1b7 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -18,6 +18,7 @@
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
+    from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
 
 logger = init_logger(__name__)
@@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
 class SupportsQuant:
     """The interface required for all models that support quantization."""
 
-    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
+    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
+    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
     quant_config: Optional[QuantizationConfig] = None
 
     def __new__(cls, *args, **kwargs) -> Self:
         instance = super().__new__(cls)
+
+        # find config passed in arguments
         quant_config = cls._find_quant_config(*args, **kwargs)
         if quant_config is not None:
+
+            # attach config to model for general use
             instance.quant_config = quant_config
-            instance.quant_config.packed_modules_mapping.update(
-                cls.packed_modules_mapping)
+
+            # apply model mappings to config for proper config-model matching
+            # NOTE: `TransformersForCausalLM` is not supported due to how this
+            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
+            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
+            if getattr(instance, "hf_to_vllm_mapper", None) is not None:
+                instance.quant_config.apply_vllm_mapper(
+                    instance.hf_to_vllm_mapper)
+            if getattr(instance, "packed_modules_mapping", None) is not None:
+                instance.quant_config.packed_modules_mapping.update(
+                    instance.packed_modules_mapping)
+
         return instance
 
     @staticmethod
     def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+        """Find quant config passed through model constructor args"""
         from vllm.config import VllmConfig  # avoid circular import
 
         args_values = list(args) + list(kwargs.values())
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index ff53a2775e3d..1b64b61a1e5c 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -61,7 +61,7 @@
 from vllm.transformers_utils.config import uses_mrope
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
@@ -821,7 +821,8 @@ def _get_mm_fields_config(
     info=Qwen2_5_VLProcessingInfo,
     dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsLoRA, SupportsPP):
+                                         SupportsLoRA, SupportsPP,
+                                         SupportsQuant):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
         self.config = config
@@ -846,7 +846,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
 
@@ -859,12 +859,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+    def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
         # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
             return None
-        return quant_config
+        return config
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 2f78d9d4cc06..04ee3a454f9d 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -467,6 +467,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
     # this makes thing complicated. We need to remove this mapper after refactor
     # `TransformersModel` in the future.
+    # NOTE: `SupportsQuant` can be updated after property decorator is removed
     @property
     def hf_to_vllm_mapper(self):
         prefix_mapper = {
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index aa88f4210160..62deb68035b9 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -4,7 +4,7 @@
 import itertools
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import Callable, Literal, Optional, Protocol, Union, overload
+from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
 
 import torch
 import torch.nn as nn
@@ -64,6 +64,19 @@ def apply(
         return ((out_name, data) for name, data in weights
                 if (out_name := self._map_name(name)) is not None)
 
+    def apply_list(self, values: list[str]) -> list[str]:
+        return [
+            out_name for name in values
+            if (out_name := self._map_name(name)) is not None
+        ]
+
+    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
+        return {
+            out_name: value
+            for name, value in values.items()
+            if (out_name := self._map_name(name)) is not None
+        }
+
 
 class AutoWeightsLoader:
     """
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index cbaa34bfc30b..2b20ca2a3ba3 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -58,7 +58,8 @@ def _synced_weight_loader(param, *args, **kwargs):
 
 
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
-    parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
+    parent_map = getattr(model, "packed_modules_mapping", None)
+    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
 
     # don't infer mapping if the model has defined it explicitly.
     if parent_map:
@@ -66,7 +67,9 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
 
     # We only check main components instead of whole model submodules
     for child in model.children():
-        child_map = getattr(child, "packed_modules_mapping", {})
+        child_map = getattr(child, "packed_modules_mapping", None)
+        child_map = copy.deepcopy(child_map) if child_map is not None else {}
+
         if any((k in parent_map and parent_map[k] != v)
                for k, v in child_map.items()):
             raise ValueError(
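
For readers skimming the patch, here is a minimal, self-contained sketch of what the new `WeightsMapper.apply_list` / `apply_dict` helpers are expected to do and why the `apply_vllm_mapper` implementations above call them. The `ToyWeightsMapper` class and its prefix-only `_map_name` are illustrative stand-ins (the real `WeightsMapper` in `vllm/model_executor/models/utils.py` supports richer mapping rules); only the `apply_list`/`apply_dict` bodies mirror the diff.

```python
from typing import Any, Optional


class ToyWeightsMapper:
    """Simplified stand-in used only to illustrate apply_list/apply_dict."""

    def __init__(self, orig_to_new_prefix: dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> Optional[str]:
        # Rewrite the first matching prefix; returning None would mean
        # "drop this name entirely".
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

    def apply_list(self, values: list[str]) -> list[str]:
        # Same comprehension as the helper added in models/utils.py
        return [
            out_name for name in values
            if (out_name := self._map_name(name)) is not None
        ]

    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
        # Keys are remapped, values (e.g. per-target scheme dicts) are kept
        return {
            out_name: value
            for name, value in values.items()
            if (out_name := self._map_name(name)) is not None
        }


if __name__ == "__main__":
    # Hypothetical HF-style names as they might appear in a quant config
    mapper = ToyWeightsMapper({"model.": "language_model.model."})
    print(mapper.apply_list(["model.layers.0.mlp", "visual.patch_embed"]))
    # ['language_model.model.layers.0.mlp', 'visual.patch_embed']
    print(mapper.apply_dict({"model.layers.0.self_attn": {"num_bits": 8}}))
    # {'language_model.model.layers.0.self_attn': {'num_bits': 8}}
```

In the same spirit, the `apply_vllm_mapper` overrides above (compressed-tensors, fp8) run these helpers over every name-keyed field of the config, so ignore lists and target scheme maps written against the HF checkpoint layout still match modules after vLLM renames them (e.g. `model.*` becoming `language_model.model.*` in Qwen2.5-VL).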