|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | 4 | from contextlib import suppress |
5 | | -from typing import Any, Literal, Optional, cast |
| 5 | +from typing import TYPE_CHECKING, Any, Literal, Optional, cast |
6 | 6 |
|
7 | 7 | import torch |
8 | 8 | from compressed_tensors.config import (CompressionFormat, |
|
37 | 37 | cutlass_fp4_supported) |
38 | 38 | from vllm.platforms import current_platform |
39 | 39 |
|
| 40 | +if TYPE_CHECKING: |
| 41 | + from vllm.model_executor.models.utils import WeightsMapper |
| 42 | + |
40 | 43 | logger = init_logger(__name__) |
41 | 44 |
|
42 | 45 | __all__ = ["CompressedTensorsLinearMethod"] |
@@ -80,6 +83,18 @@ def get_min_capability(cls) -> int: |
80 | 83 | def get_name(self) -> QuantizationMethods: |
81 | 84 | return "compressed-tensors" |
82 | 85 |
|
| 86 | + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): |
| 87 | + self.target_scheme_map = hf_to_vllm_mapper.apply_dict( |
| 88 | + self.target_scheme_map) |
| 89 | + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) |
| 90 | + self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( |
| 91 | + self.sparsity_scheme_map) |
| 92 | + self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( |
| 93 | + self.sparsity_ignore_list) |
| 94 | + if self.kv_cache_scheme is not None: |
| 95 | + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( |
| 96 | + self.kv_cache_scheme) |
| 97 | + |
83 | 98 | def get_quant_method( |
84 | 99 | self, |
85 | 100 | layer: torch.nn.Module, |
|
0 commit comments