initial fixes for compatibility with HFQuantizer #79

Merged: 16 commits, Jun 13, 2024
34 changes: 30 additions & 4 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -17,7 +17,7 @@
 import operator
 import os
 from copy import deepcopy
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
@@ -88,6 +88,15 @@ def from_pretrained(
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
return cls.from_compression_config(compression_config)

@classmethod
def from_compression_config(cls, compression_config: Dict[str, Any]):
"""
:param compression_config: compression/quantization config dictionary
found under key "quantization_config" in HF model config
:return: compressor for the extracted configs
"""
if compression_config is None:
return None

@@ -96,12 +105,16 @@ def from_pretrained(
         if sparsity_config is None and quantization_config is None:
             return None
 
-        if sparsity_config is not None:
+        if sparsity_config is not None and not isinstance(
+            sparsity_config, SparsityCompressionConfig
+        ):
             format = sparsity_config.get("format")
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 format, **sparsity_config
             )
-        if quantization_config is not None:
+        if quantization_config is not None and not isinstance(
+            quantization_config, QuantizationConfig
+        ):
             quantization_config = QuantizationConfig.parse_obj(quantization_config)
 
         return cls(
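
With the isinstance guards above, from_compression_config accepts either a raw dict read from a model's config.json or config objects that were already parsed upstream, e.g. by HFQuantizer. A minimal standalone sketch of that dict-or-object pattern, using a hypothetical PlainConfig in place of the real config classes:

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class PlainConfig:
    format: str

def coerce_config(raw: Union[dict, PlainConfig, None]) -> Optional[PlainConfig]:
    if raw is None:
        return None
    if not isinstance(raw, PlainConfig):
        # a plain dict straight from config.json: parse it
        raw = PlainConfig(**raw)
    return raw  # already-parsed configs pass through untouched

assert coerce_config({"format": "dense"}).format == "dense"
assert coerce_config(PlainConfig(format="dense")).format == "dense"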
@@ -146,16 +159,29 @@ def from_pretrained_model(
     def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
         if compression_config is None:
             return None
-        if SPARSITY_CONFIG_NAME not in compression_config:
-            return None
+        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, SPARSITY_CONFIG_NAME)
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
     @staticmethod
     def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
         if compression_config is None:
             return None
-        if QUANTIZATION_CONFIG_NAME not in compression_config:
-            return None
 
+        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
+            # for loaded HFQuantizer config
+            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
         if len(quantization_config) == 0:
             quantization_config = None
 
-        return quantization_config
+        return quantization_config.get(QUANTIZATION_CONFIG_NAME, None)
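The hasattr/getattr branches make both parsers duck-typed: a config object rebuilt by HFQuantizer exposes the sections as attributes, while a raw HF model config keeps them as dict keys. A runnable sketch of that lookup, with SimpleNamespace standing in for the loaded config object:

from types import SimpleNamespace

QUANTIZATION_CONFIG_NAME = "quantization_config"

def parse_quantization_section(compression_config):
    if compression_config is None:
        return None
    if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
        # object-style config, e.g. one loaded through HFQuantizer
        return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
    # dict-style config read straight from config.json
    return compression_config.get(QUANTIZATION_CONFIG_NAME, None)

as_dict = {"quantization_config": {"format": "int-quantized"}}
as_obj = SimpleNamespace(quantization_config={"format": "int-quantized"})
assert parse_quantization_section(as_dict) == parse_quantization_section(as_obj)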

     def __init__(
         self,
20 changes: 12 additions & 8 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -123,11 +123,14 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
         if target is not None:
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = target_to_scheme[target]
-    if set(config.ignore) - set(ignored_submodules):
-        _LOGGER.warning(
-            "Some layers that were to be ignored were "
-            f"not found in the model: {set(config.ignore) - set(ignored_submodules)}"
-        )
+
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
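
The new outer check avoids calling set() on a None ignore list, which would raise a TypeError. A simplified, self-contained sketch of the guarded warning (the names are stand-ins for the code above):

import logging

_LOGGER = logging.getLogger(__name__)

def warn_unmatched_ignores(config_ignore, ignored_submodules):
    if config_ignore is not None and ignored_submodules is not None:
        missing = set(config_ignore) - set(ignored_submodules)
        if missing:
            _LOGGER.warning(
                "Some layers that were to be ignored were "
                f"not found in the model: {missing}"
            )

warn_unmatched_ignores(None, None)                      # previously a TypeError, now a no-op
warn_unmatched_ignores(["lm_head"], ["model.lm_head"])  # logs the unmatched name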

@@ -160,9 +163,10 @@ def find_first_name_or_class_match(
     # first element of targets that matches the given name
     # if no name matches returns first target that matches the class name
     # returns None otherwise
-    return _find_first_match(name, targets) or _find_first_match(
-        module.__class__.__name__, targets, check_contains
-    )
+    if isinstance(targets, Iterable):
+        return _find_first_match(name, targets) or _find_first_match(
+            module.__class__.__name__, targets, check_contains
+        )
 
 
 def _find_first_match(
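
Guarding on Iterable means a None targets argument (for instance an absent ignore list) now falls through to an implicit None return instead of raising while iterating. A simplified sketch that keeps the guard but reduces the real name/class matching to plain equality:

from collections.abc import Iterable
from typing import List, Optional

def find_first_match_sketch(
    name: str, cls_name: str, targets: Optional[List[str]]
) -> Optional[str]:
    if isinstance(targets, Iterable):
        # first target equal to the layer name, else first equal to its class name
        return next((t for t in targets if t == name), None) or next(
            (t for t in targets if t == cls_name), None
        )
    return None  # implicit in the diff: non-iterable targets yield no match

assert find_first_match_sketch("model.layers.0.q_proj", "Linear", ["Linear"]) == "Linear"
assert find_first_match_sketch("model.layers.0.q_proj", "Linear", None) is None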
41 changes: 36 additions & 5 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -20,7 +20,10 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
_initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
if scheme.weights is not None:
if hasattr(module, "weight"):
_initialize_scale_zero_point_observer(module, "weight", scheme.weights)
weight_shape = None
if isinstance(module, torch.nn.Linear):
Satrat marked this conversation as resolved.
Show resolved Hide resolved
weight_shape = module.weight.shape
_initialize_scale_zero_point_observer(
module, "weight", scheme.weights, weight_shape=weight_shape
)
else:
_LOGGER.warning(
f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@


 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,13 +100,33 @@

     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+    if (
+        not hasattr(module, "quantization_status")
+        or getattr(module, "quantization_status") == QuantizationStatus.FROZEN
+    ):
+        expected_shape = 0
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
     init_scale = Parameter(
-        torch.empty(0, dtype=torch.float16, device=device), requires_grad=False
+        torch.empty(expected_shape, dtype=torch.float16, device=device),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=int), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
4 changes: 4 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -144,6 +144,10 @@ def model_post_init(self, __context):
             targets=targets_or_scheme,
         )
 
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
+
     @staticmethod
     def from_pretrained(
         model: Module, format: Optional[str] = None
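
transformers' HFQuantizer serialization path calls to_dict() on quantization config objects, while pydantic v1-style models expose dict(); the added alias bridges the two. A minimal stand-in (TinyQuantConfig is hypothetical, not the real QuantizationConfig):

from pydantic import BaseModel

class TinyQuantConfig(BaseModel):
    format: str = "dense"

    def to_dict(self):
        # for compatibility with HFQuantizer
        return self.dict()

assert TinyQuantConfig().to_dict() == {"format": "dense"}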
2 changes: 1 addition & 1 deletion tests/test_compressors/test_model_compressor.py
@@ -52,7 +52,7 @@ def _get_combined_config(s_config, q_config):
     combined = {}
 
     if q_config is not None:
-        combined = deepcopy(q_config)
+        combined["quantization_config"] = deepcopy(q_config)
 
     if s_config is not None:
         combined["sparsity_config"] = s_config
2 changes: 1 addition & 1 deletion tests/test_quantization/lifecycle/test_forward.py
@@ -57,9 +57,9 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status
     quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True)
     layer = Linear(4, 4)
     layer.weight.data *= 100
-    layer.quantization_status = QuantizationStatus(quantization_status)
 
     initialize_module_for_quantization(layer, quantization_scheme)
+    layer.quantization_status = QuantizationStatus(quantization_status)
 
     # only calibration updates the scale and zero-point
     if layer.quantization_status == QuantizationStatus.INITIALIZED:
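
As I read the shape-inference branch added in initialize.py above, a module whose quantization_status is unset at initialization time gets empty (shape-0) scale/zero point placeholders, so moving the assignment after initialize_module_for_quantization keeps this test creating empty parameters. A tiny standalone sketch of that branch:

def inferred_shape(status):
    # mirrors the expected_shape logic in _initialize_scale_zero_point_observer
    expected_shape = 1  # per tensor
    if status is None or status == "frozen":
        expected_shape = 0  # empty placeholder, resized later
    return expected_shape

assert inferred_shape(None) == 0           # status set after init (new ordering)
assert inferred_shape("calibration") == 1  # status set before init (old ordering)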