diff --git a/src/compressed_tensors/base.py b/src/compressed_tensors/base.py
index f01a055f..d096bc86 100644
--- a/src/compressed_tensors/base.py
+++ b/src/compressed_tensors/base.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 SPARSITY_CONFIG_NAME = "sparsity_config"
+QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py
index 1c7362eb..50d569e4 100644
--- a/src/compressed_tensors/compressors/__init__.py
+++ b/src/compressed_tensors/compressors/__init__.py
@@ -16,4 +16,5 @@
 
 from .base import ModelCompressor
 from .dense import DenseCompressor
+from .helpers import infer_compressor_from_model_config
 from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py
index 9c205f93..5ef34076 100644
--- a/src/compressed_tensors/compressors/base.py
+++ b/src/compressed_tensors/compressors/base.py
@@ -18,6 +18,7 @@
 from compressed_tensors.base import SPARSITY_CONFIG_NAME
 from compressed_tensors.config import CompressionConfig
 from compressed_tensors.registry import RegistryMixin
+from compressed_tensors.utils import get_safetensors_folder
 from torch import Tensor
 from torch.nn import Module, Parameter
 from tqdm import tqdm
@@ -62,6 +63,7 @@ def overwrite_weights(self, model_path: str, model: Module):
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
         """
+        model_path = get_safetensors_folder(model_path)
         dense_gen = self.decompress(model_path)
         for name, data in tqdm(dense_gen, desc="Decompressing model"):
             # loading the decompressed weights into the model
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/compressors/helpers.py
similarity index 100%
rename from src/compressed_tensors/utils/helpers.py
rename to src/compressed_tensors/compressors/helpers.py
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 08cb42f9..ab66cfe8 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@
 
 import re
 from collections import OrderedDict
-from typing import Iterable, Optional
+from typing import Dict, Iterable, Optional
 
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
@@ -28,14 +28,60 @@
     QuantizationStatus,
 )
 from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
 
 
 __all__ = [
+    "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
 ]
 
+from compressed_tensors.quantization.utils.helpers import is_module_quantized
+from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+
+
+def load_pretrained_quantization(model: Module, model_name_or_path: str):
+    """
+    Loads the quantization parameters (scale and zero point) from model_name_or_path
+    into a model that has already been initialized with a quantization config
+
+    :param model: model to load pretrained quantization parameters into
+    :param model_name_or_path: Hugging Face stub or local folder containing a quantized
+        model, which is used to load quantization parameters
+    """
+    model_path = get_safetensors_folder(model_name_or_path)
+    state_dict = get_quantization_state_dict(model_path)
+
+    for name, submodule in iter_named_leaf_modules(model):
+        if not is_module_quantized(submodule):
+            continue
+        if submodule.quantization_scheme.weights is not None:
+            base_name = "weight"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.input_activations is not None:
+            base_name = "input"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.output_activations is not None:
+            base_name = "output"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+
 
 def apply_quantization_config(model: Module, config: QuantizationConfig):
     """
@@ -103,3 +149,25 @@ def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
         elif target == value:
             return target
     return None
+
+
+def _load_quant_args_from_state_dict(
+    base_name: str, module_name: str, module: Module, state_dict: Dict
+):
+    """
+    Loads scale and zero point from a state_dict into the specified module
+
+    :param base_name: quantization target, one of: weight, input or output
+    :param module_name: pytorch module name to look up in state_dict
+    :param module: pytorch module associated with module_name
+    :param state_dict: state_dict to search for matching quantization parameters
+    """
+    scale_name = f"{base_name}_scale"
+    zp_name = f"{base_name}_zero_point"
+    device = next(module.parameters()).device
+
+    scale = getattr(module, scale_name)
+    zp = getattr(module, zp_name)
+    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 48b93e02..e2a198b4 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -82,6 +82,7 @@ def wrapped_forward(self, *args, **kwargs):
 
         if scheme.weights is not None:
             # calibrate and (fake) quantize weights when applicable
+            unquantized_weight = self.weight.data.clone()
             self.weight.data = _maybe_calibrate_or_quantize(
                 module, self.weight, "weight", scheme.weights
             )
@@ -97,6 +98,10 @@ def wrapped_forward(self, *args, **kwargs):
                 module, output, "output", scheme.output_activations
             )
 
+        # restore back to the unquantized weight
+        if scheme.weights is not None:
+            self.weight.data = unquantized_weight
+
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index a62a79bd..a894b4c2 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -15,6 +15,7 @@
 from enum import Enum
 from typing import Dict, List, Optional
 
+from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
@@ -24,6 +25,7 @@
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
+from transformers import AutoConfig
 
 
 __all__ = [
@@ -98,6 +100,21 @@ class QuantizationConfig(BaseModel):
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
 
+    @staticmethod
+    def from_model_config(model_name_or_path) -> "QuantizationConfig":
+        """
+        Given a path to a model config, extract a quantization config if it exists
+
+        :param model_name_or_path: path to model config on disk or HF hub
+        :return: instantiated QuantizationConfig if config contains a quant config
+        """
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+        if quantization_config is None:
+            return None
+
+        return QuantizationConfig.parse_obj(quantization_config)
+
     @staticmethod
     def from_pretrained(model: Module) -> "QuantizationConfig":
         """
diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py
index e9e78d44..5bc0fec2 100644
--- a/src/compressed_tensors/utils/__init__.py
+++ b/src/compressed_tensors/utils/__init__.py
@@ -13,5 +13,4 @@
 # limitations under the License.
 
 # flake8: noqa
-from .helpers import *
 from .safetensors_load import *
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index 4d71482a..7a9973dc 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -18,6 +18,8 @@
 import struct
 from typing import Dict, List, Optional
 
+from safetensors import safe_open
+from torch import Tensor
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file
 
 
@@ -28,6 +30,7 @@
     "merge_names",
     "get_weight_mappings",
     "get_nested_weight_mappings",
+    "get_quantization_state_dict",
 ]
 
 
@@ -45,7 +48,7 @@
     """
    if os.path.exists(pretrained_model_name_or_path):
         # argument is a path to a local folder
-        return pretrained_model_name_or_path
+        return os.path.abspath(pretrained_model_name_or_path)
 
     safetensors_path = cached_file(
         pretrained_model_name_or_path,
@@ -194,3 +197,30 @@ def get_nested_weight_mappings(
             nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
 
     return nested_weight_mappings
+
+
+def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
+    weight_mappings = get_weight_mappings(model_path)
+    state_dict = {}
+    for weight_name, safe_path in weight_mappings.items():
+        if not _is_quantization_weight(weight_name):
+            continue
+        with safe_open(safe_path, framework="pt", device="cpu") as f:
+            state_dict[weight_name] = f.get_tensor(weight_name)
+
+    return state_dict
+
+
+def _is_quantization_weight(name: str) -> bool:
+    """
+    Checks if a parameter name is associated with a quantization parameter
+
+    :param name: parameter name to check
+    :return: True if parameter name is a quantization parameter, else False
+    """
+    if name.endswith("_scale"):
+        return True
+    if name.endswith("zero_point"):
+        return True
+
+    return False
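
Usage sketch (not part of the patch): a minimal example of how the entry points added above might be wired together. The import paths are inferred from the files touched in this diff, the checkpoint location is a placeholder, and it assumes that applying the quantization config initializes the scale and zero-point parameters that load_pretrained_quantization then fills in.

from transformers import AutoModelForCausalLM

from compressed_tensors.quantization.lifecycle.apply import (
    apply_quantization_config,
    load_pretrained_quantization,
)
from compressed_tensors.quantization.quant_config import QuantizationConfig

# placeholder: local folder or Hugging Face stub of a quantized model
model_stub = "path/to/quantized-model"

# load the model architecture and weights
model = AutoModelForCausalLM.from_pretrained(model_stub)

# pull the serialized quantization config out of the HF model config, if present
quant_config = QuantizationConfig.from_model_config(model_stub)

if quant_config is not None:
    # attach quantization schemes (and their scale/zero-point parameters) to modules
    apply_quantization_config(model, quant_config)
    # then reload the saved scales and zero points from the safetensors checkpoint
    load_pretrained_quantization(model, model_stub)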