Pretrained Model Reload + SparseGPT Support (#31)
* model reload working

* fix for sparseGPT

* docstrings
Sara Adkins authored Apr 23, 2024
1 parent 06200fc commit 67005d7
Showing 9 changed files with 126 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
@@ -13,3 +13,4 @@
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
1 change: 1 addition & 0 deletions src/compressed_tensors/compressors/__init__.py
@@ -16,4 +16,5 @@

from .base import ModelCompressor
from .dense import DenseCompressor
from .helpers import infer_compressor_from_model_config
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
2 changes: 2 additions & 0 deletions src/compressed_tensors/compressors/base.py
@@ -18,6 +18,7 @@
from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.config import CompressionConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_safetensors_folder
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -62,6 +63,7 @@ def overwrite_weights(self, model_path: str, model: Module):
:param model_path: path to compressed weights
:param model: pytorch model to load decompressed weights into
"""
model_path = get_safetensors_folder(model_path)
dense_gen = self.decompress(model_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
# loading the decompressed weights into the model
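The change above routes `overwrite_weights` through `get_safetensors_folder`, so a Hugging Face stub now works as well as a local folder. Below is a minimal, hedged sketch of the reload path; it assumes `infer_compressor_from_model_config` (exported from `compressed_tensors.compressors`) accepts a model path or stub and returns a `ModelCompressor` or `None`, and the checkpoint path and `build_model()` are placeholders.

```python
# Hedged sketch, not part of this commit: reload decompressed weights into a model.
# Assumes infer_compressor_from_model_config(path_or_stub) -> Optional[ModelCompressor].
from compressed_tensors.compressors import infer_compressor_from_model_config

model_path = "./compressed_model"  # hypothetical checkpoint (local folder or HF stub)
model = build_model()              # placeholder: a torch.nn.Module matching the checkpoint

compressor = infer_compressor_from_model_config(model_path)
if compressor is not None:
    # get_safetensors_folder() is now called inside overwrite_weights, so the raw
    # stub/path can be passed directly
    compressor.overwrite_weights(model_path=model_path, model=model)
```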
File renamed without changes.
70 changes: 69 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -14,7 +14,7 @@

import re
from collections import OrderedDict
from typing import Iterable, Optional
from typing import Dict, Iterable, Optional

from compressed_tensors.quantization.lifecycle.calibration import (
set_module_for_calibration,
@@ -28,14 +28,60 @@
QuantizationStatus,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from torch.nn import Module


__all__ = [
"load_pretrained_quantization",
"apply_quantization_config",
"apply_quantization_status",
]

from compressed_tensors.quantization.utils.helpers import is_module_quantized
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


def load_pretrained_quantization(model: Module, model_name_or_path: str):
"""
Loads the quantization parameters (scale and zero point) from model_name_or_path into
a model that has already been initialized with a quantization config
:param model: model to load pretrained quantization parameters to
:param model_name_or_path: Hugging Face stub or local folder containing a quantized
model, which is used to load quantization parameters
"""
model_path = get_safetensors_folder(model_name_or_path)
state_dict = get_quantization_state_dict(model_path)

for name, submodule in iter_named_leaf_modules(model):
if not is_module_quantized(submodule):
continue
if submodule.quantization_scheme.weights is not None:
base_name = "weight"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)
if submodule.quantization_scheme.input_activations is not None:
base_name = "input"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)
if submodule.quantization_scheme.output_activations is not None:
base_name = "output"
_load_quant_args_from_state_dict(
base_name=base_name,
module_name=name,
module=submodule,
state_dict=state_dict,
)


def apply_quantization_config(model: Module, config: QuantizationConfig):
"""
@@ -103,3 +149,25 @@ def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
elif target == value:
return target
return None


def _load_quant_args_from_state_dict(
base_name: str, module_name: str, module: Module, state_dict: Dict
):
"""
Loads scale and zero point from a state_dict into the specified module
:param base_name: quantization target, one of: weight, input or output
:param module_name: pytorch module name to look up in state_dict
:param module: pytorch module associated with module_name
:param state_dict: state_dict to search for matching quantization parameters
"""
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
device = next(module.parameters()).device

scale = getattr(module, scale_name)
zp = getattr(module, zp_name)
scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
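Taken together, the helpers added above give a reload path for quantized checkpoints: read the serialized config, re-apply it to a freshly initialized model, then load the stored scales and zero points. A hedged sketch, using the module paths from this diff; the stub and `build_model()` are placeholders.

```python
# Hedged sketch, not part of this commit: reloading a quantized checkpoint.
from compressed_tensors.quantization.lifecycle.apply import (
    apply_quantization_config,
    load_pretrained_quantization,
)
from compressed_tensors.quantization.quant_config import QuantizationConfig

model_stub = "org/quantized-model"  # hypothetical HF stub or local folder
model = build_model()               # placeholder: matching torch.nn.Module

config = QuantizationConfig.from_model_config(model_stub)
if config is not None:
    # attach the quantization scheme and parameters to each targeted module ...
    apply_quantization_config(model, config)
    # ... then overwrite the initialized scales and zero points with the saved ones
    load_pretrained_quantization(model, model_stub)
```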
5 changes: 5 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -82,6 +82,7 @@ def wrapped_forward(self, *args, **kwargs):

if scheme.weights is not None:
# calibrate and (fake) quantize weights when applicable
unquantized_weight = self.weight.data.clone()
self.weight.data = _maybe_calibrate_or_quantize(
module, self.weight, "weight", scheme.weights
)
@@ -97,6 +98,10 @@ def wrapped_forward(self, *args, **kwargs):
module, output, "output", scheme.output_activations
)

# restore the weight to its original, unquantized value
if scheme.weights is not None:
self.weight.data = unquantized_weight

return output

# bind wrapped forward to module class so reference to `self` is correct
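The `forward.py` change clones the dense weight before fake quantization and restores it after the forward pass, so the module's stored parameters stay unquantized between calls (plausibly the sparseGPT fix named in the commit message). A toy, hedged illustration of that clone/restore pattern on a plain `torch.nn.Linear`:

```python
# Toy illustration (not from this commit) of the clone/restore pattern above.
import torch

linear = torch.nn.Linear(4, 4)
original = linear.weight.data.clone()

# stand-in for fake quantization: the forward pass sees rounded weights
linear.weight.data = torch.round(linear.weight.data * 16) / 16
out = linear(torch.randn(1, 4))

# restore, as the wrapped forward now does, so the stored weight stays full precision
linear.weight.data = original
assert torch.equal(linear.weight.data, original)
```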
17 changes: 17 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -15,6 +15,7 @@
from enum import Enum
from typing import Dict, List, Optional

from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
calculate_compression_ratio,
@@ -24,6 +25,7 @@
)
from pydantic import BaseModel, Field
from torch.nn import Module
from transformers import AutoConfig


__all__ = [
@@ -98,6 +100,21 @@ class QuantizationConfig(BaseModel):
global_compression_ratio: Optional[float] = None
ignore: Optional[List[str]] = Field(default_factory=list)

@staticmethod
def from_model_config(model_name_or_path) -> Optional["QuantizationConfig"]:
"""
Given a path to a model config, extract a quantization config if it exists
:param model_name_or_path: path to model config on disk or HF hub
:return: instantiated QuantizationConfig if the config contains a quantization config, otherwise None
"""
config = AutoConfig.from_pretrained(model_name_or_path)
quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
if quantization_config is None:
return None

return QuantizationConfig.parse_obj(quantization_config)

@staticmethod
def from_pretrained(model: Module) -> "QuantizationConfig":
"""
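`from_model_config` is the read-side counterpart of the existing `from_pretrained`: one parses the config serialized under `sparseml_quantization_config` in a checkpoint's `config.json`, the other derives a config from an already-quantized in-memory model. A hedged usage sketch; the path is hypothetical and only fields visible in this diff are accessed.

```python
# Hedged sketch, not part of this commit.
from compressed_tensors.quantization.quant_config import QuantizationConfig

# read side: returns None when config.json has no sparseml_quantization_config entry
quant_config = QuantizationConfig.from_model_config("./quantized_model")
if quant_config is not None:
    print(quant_config.global_compression_ratio, quant_config.ignore)

# write side: derive a config from an in-memory quantized model (pydantic BaseModel,
# so .dict() gives a JSON-serializable payload)
# config_dict = QuantizationConfig.from_pretrained(model).dict()
```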
1 change: 0 additions & 1 deletion src/compressed_tensors/utils/__init__.py
@@ -13,5 +13,4 @@
# limitations under the License.
# flake8: noqa

from .helpers import *
from .safetensors_load import *
32 changes: 31 additions & 1 deletion src/compressed_tensors/utils/safetensors_load.py
@@ -18,6 +18,8 @@
import struct
from typing import Dict, List, Optional

from safetensors import safe_open
from torch import Tensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file


@@ -28,6 +30,7 @@
"merge_names",
"get_weight_mappings",
"get_nested_weight_mappings",
"get_quantization_state_dict",
]


@@ -45,7 +48,7 @@ def get_safetensors_folder(
"""
if os.path.exists(pretrained_model_name_or_path):
# argument is a path to a local folder
return pretrained_model_name_or_path
return os.path.abspath(pretrained_model_name_or_path)

safetensors_path = cached_file(
pretrained_model_name_or_path,
@@ -194,3 +197,30 @@ def get_nested_weight_mappings(
nested_weight_mappings[dense_param][param_name] = weight_mappings[key]

return nested_weight_mappings


def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
"""
Loads only the quantization parameters (scales and zero points) from a safetensors model folder
:param model_path: local folder containing the model's safetensors files
:return: mapping of parameter name to loaded tensor
"""
weight_mappings = get_weight_mappings(model_path)
state_dict = {}
for weight_name, safe_path in weight_mappings.items():
if not _is_quantization_weight(weight_name):
continue
with safe_open(safe_path, framework="pt", device="cpu") as f:
state_dict[weight_name] = f.get_tensor(weight_name)

return state_dict


def _is_quantization_weight(name: str) -> bool:
"""
Checks if a parameter name is associated with a quantization parameter
:param name: parameter name to check
:return: True if parameter name is a quantization parameter, else False
"""
if name.endswith("_scale"):
return True
if name.endswith("zero_point"):
return True

return False
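The new `get_quantization_state_dict` helper walks the safetensors weight mappings and keeps only tensors whose names end in `_scale` or `zero_point`. A hedged usage sketch; the checkpoint path is hypothetical.

```python
# Hedged sketch, not part of this commit: load only quantization parameters.
from compressed_tensors.utils.safetensors_load import (
    get_quantization_state_dict,
    get_safetensors_folder,
)

folder = get_safetensors_folder("./quantized_model")  # resolves local paths or HF stubs
quant_state_dict = get_quantization_state_dict(folder)

# only *_scale and *zero_point tensors are returned
for name, tensor in quant_state_dict.items():
    print(name, tuple(tensor.shape))
```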
