initial fixes for compatibility with HFQuantizer #79

Merged · 16 commits · Jun 13, 2024
43 changes: 39 additions & 4 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -17,7 +17,7 @@
import operator
import os
from copy import deepcopy
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

from compressed_tensors.base import (
COMPRESSION_CONFIG_NAME,
@@ -88,20 +88,41 @@ def from_pretrained(
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
return cls.from_compression_config(compression_config)

@classmethod
def from_compression_config(cls, compression_config: Dict[str, Any]):
"""
:param compression_config: compression/quantization config dictionary
found under key "quantization_config" in HF model config
:return: compressor for the extracted configs
"""
if compression_config is None:
return None

try:
from transformers.utils.quantization_config import CompressedTensorsConfig

if isinstance(compression_config, CompressedTensorsConfig):
compression_config = compression_config.to_dict()
except ImportError:
pass

sparsity_config = cls.parse_sparsity_config(compression_config)
quantization_config = cls.parse_quantization_config(compression_config)
if sparsity_config is None and quantization_config is None:
return None

if sparsity_config is not None:
if sparsity_config is not None and not isinstance(
sparsity_config, SparsityCompressionConfig
):
format = sparsity_config.get("format")
sparsity_config = SparsityCompressionConfig.load_from_registry(
format, **sparsity_config
)
if quantization_config is not None:
if quantization_config is not None and not isinstance(
quantization_config, QuantizationConfig
):
quantization_config = QuantizationConfig.parse_obj(quantization_config)

return cls(
Expand Down Expand Up @@ -146,15 +167,29 @@ def from_pretrained_model(
def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
if compression_config is None:
return None
if SPARSITY_CONFIG_NAME not in compression_config:
return None
if hasattr(compression_config, SPARSITY_CONFIG_NAME):
# for loaded HFQuantizer config
return getattr(compression_config, SPARSITY_CONFIG_NAME)

# SparseAutoModel format
return compression_config.get(SPARSITY_CONFIG_NAME, None)

@staticmethod
def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
if compression_config is None:
return None

if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
# for loaded HFQuantizer config
return getattr(compression_config, QUANTIZATION_CONFIG_NAME)

# SparseAutoModel format
quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
if len(quantization_config) == 0:
quantization_config = None

return quantization_config

def __init__(
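A minimal sketch of how the new `from_compression_config` entry point could be exercised, assuming the changes above; the model stub and import paths are illustrative, not taken from the PR:

```python
from transformers import AutoConfig
from compressed_tensors.compressors import ModelCompressor

# "org/model-stub" is a placeholder for a model saved with a compressed-tensors config
config = AutoConfig.from_pretrained("org/model-stub")
# per the docstring above, the config lives under "quantization_config" in the HF model config
quantization_config = getattr(config, "quantization_config", None)

# accepts either the raw dict stored in the model config or a transformers
# CompressedTensorsConfig object, and returns None if no configs are found
compressor = ModelCompressor.from_compression_config(quantization_config)
if compressor is not None:
    print(compressor.sparsity_config, compressor.quantization_config)
```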
28 changes: 18 additions & 10 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -123,11 +123,14 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
if target is not None:
# target matched - add layer and scheme to target list
submodule.quantization_scheme = target_to_scheme[target]
if set(config.ignore) - set(ignored_submodules):
_LOGGER.warning(
"Some layers that were to be ignored were "
f"not found in the model: {set(config.ignore) - set(ignored_submodules)}"
)

if config.ignore is not None and ignored_submodules is not None:
if set(config.ignore) - set(ignored_submodules):
_LOGGER.warning(
"Some layers that were to be ignored were "
"not found in the model: "
f"{set(config.ignore) - set(ignored_submodules)}"
)
# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)

@@ -146,7 +149,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):

if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
model.apply(set_module_for_calibration)

if current_status < status >= QuantizationStatus.FROZEN > current_status:
model.apply(freeze_module_quantization)

@@ -160,9 +162,10 @@ def find_first_name_or_class_match(
# first element of targets that matches the given name
# if no name matches returns first target that matches the class name
# returns None otherwise
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets, check_contains
)
if isinstance(targets, Iterable):
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets, check_contains
)


def _find_first_match(
@@ -212,7 +215,12 @@ def _load_quant_args_from_state_dict(
scale = getattr(module, scale_name, None)
zp = getattr(module, zp_name, None)
if scale is not None:
scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
if state_dict_scale is not None:
scale.data = state_dict_scale.to(device).to(scale.dtype)
else:
scale.data = scale.data.to(device)

if zp is not None:
zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
if zp_from_state is not None: # load the non-zero zero points
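A minimal sketch of the scale-loading fallback added in `_load_quant_args_from_state_dict`, using a toy parameter and state dict; the key name is illustrative:

```python
import torch

# toy module parameter and serialized state dict; key name is a placeholder
scale = torch.nn.Parameter(torch.empty(1, dtype=torch.float16), requires_grad=False)
state_dict = {"model.layers.0.weight_scale": torch.tensor([0.02])}
device = "cpu"

state_dict_scale = state_dict.get("model.layers.0.weight_scale")
if state_dict_scale is not None:
    # cast to the parameter's dtype so an fp32 checkpoint loads into an fp16 module
    scale.data = state_dict_scale.to(device).to(scale.dtype)
else:
    # nothing serialized for this module: keep the initialized value, only move device
    scale.data = scale.data.to(device)
```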
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/lifecycle/forward.py
@@ -94,7 +94,7 @@ def dequantize(
:return: dequantized float tensor
"""
if args is None:
if scale.ndim == 0:
if scale.ndim == 0 or scale.ndim == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
elif scale.ndim == 2:
if scale.shape[1] == 1:
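A small sketch mirroring the strategy inference that `dequantize` now performs: 0-dim and 1-dim scales map to per-tensor arguments, while 2-dim scales select channel or group quantization. The helper function is hypothetical, written only to illustrate the branch logic:

```python
import torch

def infer_strategy(scale: torch.Tensor) -> str:
    # hypothetical helper mirroring the ndim checks in dequantize()
    if scale.ndim in (0, 1):
        return "tensor"
    if scale.ndim == 2:
        return "channel" if scale.shape[1] == 1 else "group"
    raise ValueError(f"unsupported scale shape {tuple(scale.shape)}")

print(infer_strategy(torch.tensor(0.1)))    # tensor (0-dim)
print(infer_strategy(torch.tensor([0.1])))  # tensor (shape (1,), newly accepted)
print(infer_strategy(torch.rand(64, 1)))    # channel
print(infer_strategy(torch.rand(64, 8)))    # group
```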
36 changes: 31 additions & 5 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -20,7 +20,10 @@
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
_initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
if scheme.weights is not None:
if hasattr(module, "weight"):
_initialize_scale_zero_point_observer(module, "weight", scheme.weights)
weight_shape = None
if isinstance(module, torch.nn.Linear):
weight_shape = module.weight.shape
_initialize_scale_zero_point_observer(
module, "weight", scheme.weights, weight_shape=weight_shape
)
else:
_LOGGER.warning(
f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@


def _initialize_scale_zero_point_observer(
module: Module, base_name: str, quantization_args: QuantizationArgs
module: Module,
base_name: str,
quantization_args: QuantizationArgs,
weight_shape: Optional[torch.Size] = None,
):
# initialize observer module and attach as submodule
observer = quantization_args.get_observer()
@@ -89,13 +100,28 @@

device = next(module.parameters()).device

# infer expected scale/zero point shape
expected_shape = 1 # per tensor

if base_name == "weight" and weight_shape is not None:
if quantization_args.strategy == QuantizationStrategy.CHANNEL:
# (output_channels, 1)
expected_shape = (weight_shape[0], 1)
elif quantization_args.strategy == QuantizationStrategy.GROUP:
expected_shape = (
weight_shape[0],
weight_shape[1] // quantization_args.group_size,
)

# initializes empty scale and zero point parameters for the module
init_scale = Parameter(
torch.empty(0, dtype=torch.float16, device=device), requires_grad=False
torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
requires_grad=False,
)
module.register_parameter(f"{base_name}_scale", init_scale)

init_zero_point = Parameter(
torch.empty(0, device=device, dtype=int), requires_grad=False
torch.empty(expected_shape, device=device, dtype=int),
requires_grad=False,
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)
4 changes: 4 additions & 0 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -51,4 +51,8 @@ def calculate_qparams(
zero_points = bit_min - torch.round(min_vals / scales)
zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)

if scales.ndim == 0:
scales = scales.reshape(1)
zero_points = zero_points.reshape(1)

return scales, zero_points
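A quick check of the reshape added to `calculate_qparams`: per-tensor qparams now come back with a single element rather than as 0-dim tensors, matching the new initialized parameter shapes; the values are illustrative:

```python
import torch

scales = torch.tensor(0.05)                      # 0-dim output of a per-tensor observer
zero_points = torch.tensor(0, dtype=torch.int8)

if scales.ndim == 0:
    scales = scales.reshape(1)
    zero_points = zero_points.reshape(1)

assert scales.shape == (1,) and zero_points.shape == (1,)
```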
4 changes: 4 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -144,6 +144,10 @@ def model_post_init(self, __context):
targets=targets_or_scheme,
)

def to_dict(self):
# for compatibility with HFQuantizer
return self.dict()

@staticmethod
def from_pretrained(
model: Module, format: Optional[str] = None
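A sketch of how the new `to_dict` hook might be exercised when a config is handed back to HFQuantizer; the scheme below is illustrative and assumes the pydantic-style `parse_obj` usage shown earlier in the diff:

```python
from compressed_tensors.quantization import QuantizationConfig

# illustrative 8-bit weight-only scheme targeting Linear layers
config = QuantizationConfig.parse_obj(
    {
        "config_groups": {
            "group_0": {
                "weights": {"num_bits": 8, "symmetric": True},
                "targets": ["Linear"],
            }
        }
    }
)

# HFQuantizer expects a plain dict when re-serializing into the model config
print(config.to_dict())
```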
2 changes: 1 addition & 1 deletion tests/test_quantization/lifecycle/test_forward.py
@@ -57,9 +57,9 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_status):
quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True)
layer = Linear(4, 4)
layer.weight.data *= 100
layer.quantization_status = QuantizationStatus(quantization_status)

initialize_module_for_quantization(layer, quantization_scheme)
layer.quantization_status = QuantizationStatus(quantization_status)

# only calibration updates the scale and zero-point
if layer.quantization_status == QuantizationStatus.INITIALIZED:
8 changes: 4 additions & 4 deletions tests/test_quantization/lifecycle/test_lifecycle.py
@@ -71,10 +71,10 @@ def test_lifecyle(create_quantization_scheme):
assert layer.quantization_status == QuantizationStatus.CALIBRATION

# do a calibration step
assert torch.numel(layer.input_zero_point.data) == 0
assert torch.numel(layer.input_scale) == 0
assert torch.numel(layer.weight_scale) == 0
assert torch.numel(layer.weight_zero_point) == 0
assert torch.numel(layer.input_zero_point.data) == 1
assert torch.numel(layer.input_scale) == 1
assert torch.numel(layer.weight_scale) == 1
assert torch.numel(layer.weight_zero_point) == 1

layer(torch.randn(4, 4))
