1 change: 1 addition & 0 deletions tests/quantization/test_register_quantization_config.py
@@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):

def __init__(self, num_bits: int = 8) -> None:
"""Initialize the quantization config."""
super().__init__()
self.num_bits = num_bits

def get_name(self) -> QuantizationMethods:
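This and several later hunks add an explicit `super().__init__()` call to quantization configs. The likely reason: the base `QuantizationConfig.__init__` creates per-instance state such as the `packed_modules_mapping` dict, and the `.update()` call in `SupportsQuant.__new__` later in this diff relies on that attribute already existing. A reduced illustration of the pattern (not the real vLLM classes):

```python
class BaseConfig:
    def __init__(self) -> None:
        # Per-instance mapping that model code later updates by reference.
        self.packed_modules_mapping: dict[str, list[str]] = {}


class CustomQuantConfig(BaseConfig):
    def __init__(self, num_bits: int = 8) -> None:
        super().__init__()  # without this, packed_modules_mapping never exists
        self.num_bits = num_bits


config = CustomQuantConfig()
config.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
```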
2 changes: 1 addition & 1 deletion vllm/lora/models.py
@@ -805,7 +805,7 @@ def create_lora_manager(
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not hasattr(model, "packed_modules_mapping"):
if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
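The `hasattr` check becomes `isinstance(model, SupportsLoRA)`, which assumes `SupportsLoRA` is a runtime-checkable protocol (as defined in vllm/model_executor/models/interfaces.py): the check stays structural, so models need not inherit from the class, but the required interface is now named explicitly. A toy sketch of the mechanism with a made-up protocol:

```python
from typing import ClassVar, Protocol, runtime_checkable


@runtime_checkable
class SupportsFoo(Protocol):
    # Structural requirement: any object exposing this attribute matches.
    packed_modules_mapping: ClassVar[dict[str, list[str]]]


class MyModel:
    # No inheritance from SupportsFoo needed.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


assert isinstance(MyModel(), SupportsFoo)
```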
5 changes: 1 addition & 4 deletions vllm/lora/worker_manager.py
@@ -111,10 +111,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = None
if (hasattr(model, "hf_to_vllm_mapper")
and model.hf_to_vllm_mapper is not None):
hf_to_vllm_mapper = model.hf_to_vllm_mapper
hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
13 changes: 13 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
@@ -10,6 +10,7 @@

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
else:
QuantizationMethods = str

@@ -149,3 +150,15 @@ def get_quant_method(self, layer: torch.nn.Module,

def get_cache_scale(self, name: str) -> Optional[str]:
return None

def apply_vllm_mapper( # noqa: B027
self, hf_to_vllm_mapper: "WeightsMapper"):
"""
Interface for models to update module names referenced in
quantization configs in order to reflect the vllm model structure

:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass
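The default is intentionally a no-op (hence the `# noqa: B027` suppression on the empty method), so configs that don't override it keep working. A config that stores HF-style module names only needs to run them through the mapper; a minimal, hypothetical override, alongside the real ones added below for `CompressedTensorsConfig` and `Fp8Config`:

```python
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.models.utils import WeightsMapper


class MyQuantConfig(QuantizationConfig):  # other abstract methods omitted
    def __init__(self, exclude_modules: list[str]) -> None:
        super().__init__()
        self.exclude_modules = exclude_modules

    def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None:
        # Rewrite HF-checkpoint module names into the names used by the
        # vLLM module tree, so ignore/target matching still lines up.
        self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules)
```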
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/bitblas.py
@@ -63,6 +63,7 @@ def __init__(
# (since we have only one group per output channel)
desc_act = False

super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import suppress
from typing import Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
from compressed_tensors.config import (CompressionFormat,
@@ -37,6 +37,9 @@
cutlass_fp4_supported)
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]
@@ -80,6 +83,18 @@ def get_min_capability(cls) -> int:
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"

def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
self.sparsity_scheme_map)
self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
self.sparsity_ignore_list)
if self.kv_cache_scheme is not None:
self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
self.kv_cache_scheme)

def get_quant_method(
self,
layer: torch.nn.Module,
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -39,6 +39,9 @@
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm

if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)
@@ -100,6 +103,11 @@ def get_min_capability(cls) -> int:
def get_config_filenames(cls) -> list[str]:
return []

def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.ignored_layers is not None:
self.ignored_layers = hf_to_vllm_mapper.apply_list(
self.ignored_layers)

@classmethod
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq_bitblas.py
@@ -81,6 +81,7 @@ def __init__(
# (since we have only one group per output channel)
desc_act = False

super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/marlin.py
@@ -32,6 +32,8 @@ def __init__(
group_size: int,
lm_head_quantized: bool,
) -> None:
super().__init__()

# Group size for the quantization.
self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -181,6 +181,7 @@ def __init__(
exclude_modules: list[str],
group_size: int = 16,
) -> None:
super().__init__()
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -55,6 +55,7 @@ def __init__(self,
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
super().__init__()
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []

22 changes: 13 additions & 9 deletions vllm/model_executor/model_loader/utils.py
@@ -24,6 +24,7 @@
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,

Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (e.g. chatglm, qwen).

Once the `SupportsQuant` mixin has been added to all models, this
function can be removed.
"""
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None:
# pass packed_modules_mapping by reference to quant_config
quant_config.packed_modules_mapping = packed_mapping
else:
logger.warning(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__)
if not issubclass(model_class, SupportsQuant):
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
packed_mapping = getattr(model_class, "packed_modules_mapping", None)

# pass mappings by reference to quant_config
if hf_to_vllm_mapper is not None:
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if packed_mapping is not None:
quant_config.packed_modules_mapping = packed_mapping
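The helper now covers only the legacy path: model classes that already mix in `SupportsQuant` do this work in `__new__` (see `interfaces.py` below), while everything else gets its class-level mapper and packed-modules mapping pushed into the quant config here, before the model is built. A rough sketch of the effect, with an invented model class and mapping values:

```python
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.model_loader.utils import configure_quant_config
from vllm.model_executor.models.utils import WeightsMapper


class LegacyModel:  # stands in for a model class without the SupportsQuant mixin
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


quant_config = Fp8Config(ignored_layers=["model.embed_tokens"])
configure_quant_config(quant_config, LegacyModel)

print(quant_config.ignored_layers)          # ['language_model.model.embed_tokens']
print(quant_config.packed_modules_mapping)  # {'qkv_proj': ['q_proj', 'k_proj', 'v_proj']}
```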
23 changes: 20 additions & 3 deletions vllm/model_executor/models/interfaces.py
@@ -18,6 +18,7 @@

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)
@@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
class SupportsQuant:
"""The interface required for all models that support quantization."""

packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
quant_config: Optional[QuantizationConfig] = None

def __new__(cls, *args, **kwargs) -> Self:
instance = super().__new__(cls)

# find config passed in arguments
quant_config = cls._find_quant_config(*args, **kwargs)
if quant_config is not None:

# attach config to model for general use
instance.quant_config = quant_config
instance.quant_config.packed_modules_mapping.update(
cls.packed_modules_mapping)

# apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this
# class defines `hf_to_vllm_mapper` as a post-init `@property`.
# After this is fixed, get `instance.hf_to_vllm_mapper` directly
if getattr(instance, "hf_to_vllm_mapper", None) is not None:
instance.quant_config.apply_vllm_mapper(
instance.hf_to_vllm_mapper)
if getattr(instance, "packed_modules_mapping", None) is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping)

return instance

@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
"""Find quant config passed through model constructor args"""
from vllm.config import VllmConfig # avoid circular import

args_values = list(args) + list(kwargs.values())
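Taken together, a model now only declares its mappings as class attributes and mixes in `SupportsQuant`; `__new__` finds the quant config in the constructor arguments, attaches it, and reconciles it with the model structure before `__init__` runs. A condensed sketch of the intended usage (the model and its mappings are invented):

```python
import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper


class ToyModel(nn.Module, SupportsQuant):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config, prefix: str = ""):
        super().__init__()
        # SupportsQuant.__new__ has already set self.quant_config (possibly
        # None) and applied hf_to_vllm_mapper / packed_modules_mapping to it.
```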
14 changes: 7 additions & 7 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -61,7 +61,7 @@
from vllm.transformers_utils.config import uses_mrope

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
SupportsMultiModal, SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
@@ -821,7 +821,8 @@ def _get_mm_fields_config(
info=Qwen2_5_VLProcessingInfo,
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
SupportsLoRA, SupportsPP,
SupportsQuant):

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

self.config = config
Expand All @@ -846,7 +846,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=self._maybe_ignore_quant_config(self.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)

Expand All @@ -859,12 +859,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
return config

def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
1 change: 1 addition & 0 deletions vllm/model_executor/models/transformers.py
@@ -467,6 +467,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes things complicated. We need to remove this mapper after
# refactoring `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
15 changes: 14 additions & 1 deletion vllm/model_executor/models/utils.py
@@ -4,7 +4,7 @@
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import Callable, Literal, Optional, Protocol, Union, overload
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload

import torch
import torch.nn as nn
@@ -64,6 +64,19 @@ def apply(
return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None)

def apply_list(self, values: list[str]) -> list[str]:
return [
out_name for name in values
if (out_name := self._map_name(name)) is not None
]

def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
return {
out_name: value
for name, value in values.items()
if (out_name := self._map_name(name)) is not None
}
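These two helpers mirror `apply`, but operate on plain name lists and name-keyed dicts, which is exactly the shape of the quant-config fields remapped above (ignore lists, target/scheme maps). A quick usage sketch with made-up names:

```python
from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

# List entries are renamed; entries that map to None are dropped.
mapper.apply_list(["model.layers.0.mlp.gate_proj", "visual.patch_embed"])
# -> ['language_model.model.layers.0.mlp.gate_proj', 'visual.patch_embed']

# Dict keys are renamed; values are carried over unchanged.
mapper.apply_dict({"model.layers.0.self_attn": {"num_bits": 8}})
# -> {'language_model.model.layers.0.self_attn': {'num_bits': 8}}
```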


class AutoWeightsLoader:
"""
7 changes: 5 additions & 2 deletions vllm/model_executor/utils.py
@@ -58,15 +58,18 @@ def _synced_weight_loader(param, *args, **kwargs):


def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
parent_map = getattr(model, "packed_modules_mapping", None)
parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}

# don't infer mapping if the model has defined it explicitly.
if parent_map:
return parent_map

# We only check main components instead of whole model submodules
for child in model.children():
child_map = getattr(child, "packed_modules_mapping", {})
child_map = getattr(child, "packed_modules_mapping", None)
child_map = copy.deepcopy(child_map) if child_map is not None else {}

if any((k in parent_map and parent_map[k] != v)
for k, v in child_map.items()):
raise ValueError(