
Commit

GPTQ - move calibration of quantization params to after hessian calibration (#25)
bfineran authored Jul 22, 2024
1 parent f5e1a10 commit 476d1eb
Showing 3 changed files with 31 additions and 4 deletions.
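
The ordering being changed is easiest to see outside the diff. Below is a minimal, self-contained sketch, in plain PyTorch rather than llm-compressor or compressed-tensors code, of the flow this commit enforces: calibration forward passes (e.g. Hessian accumulation) run with quantization disabled, weight quantization parameters are computed afterwards from the up-to-date weights, and quantization is then re-enabled (and subsequently frozen). Every class and attribute name in the sketch is invented for illustration.

import torch
import torch.nn as nn


class ToyQuantLinear(nn.Module):
    """Linear layer with a toggleable fake-quantized weight (illustration only)."""

    def __init__(self, in_features: int, out_features: int, num_bits: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.num_bits = num_bits
        self.quant_enabled = False   # stands in for disable_/enable_quantization
        self.scale = None            # weight quant params, filled in later
        self.zero_point = None

    def compute_weight_quant_params(self):
        # symmetric per-tensor params derived from the current weights
        qmax = 2 ** (self.num_bits - 1) - 1
        self.scale = self.weight.detach().abs().max() / qmax
        self.zero_point = torch.zeros((), dtype=torch.int32)

    def forward(self, x):
        w = self.weight
        if self.quant_enabled and self.scale is not None:
            # fake-quantize the weight only once params exist and quant is enabled
            qmax = 2 ** (self.num_bits - 1) - 1
            w = torch.clamp(torch.round(w / self.scale), -qmax - 1, qmax) * self.scale
        return x @ w.t()


layer = ToyQuantLinear(16, 8)
calib_batches = [torch.randn(4, 16) for _ in range(8)]

# 1) calibration forward passes see the unquantized weights
layer.quant_enabled = False
with torch.no_grad():
    for batch in calib_batches:
        layer(batch)

# 2) weight quantization params are computed only after calibration
layer.compute_weight_quant_params()

# 3) quantization is switched back on (and would then be frozen)
layer.quant_enabled = True
print(layer(calib_batches[0]).shape)  # torch.Size([4, 8])
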
20 changes: 18 additions & 2 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,7 +1,12 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import QuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    disable_quantization,
+    enable_quantization,
+    freeze_module_quantization,
+)
 from loguru import logger
 from pydantic import Field
 from torch.nn import Module
@@ -163,7 +168,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         if not self.initialized_structure_:
             self.on_initialize_structure(state, **kwargs)
         if self.quantization_modifier_:
-            self.quantization_modifier_.initialize(state, **kwargs)
+            self.quantization_modifier_.initialize(
+                state, freeze_quantization=False, **kwargs
+            )
         if not self.quantize:
             raise ValueError("To use the GPTQModifier, quantization must be enabled.")

@@ -178,6 +185,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         self.initialize_compression(modifiable_model, calibration_dataloader)
         self.apply_compression(calibration_dataloader)
+        state.model.apply(freeze_module_quantization)
 
         return True

@@ -250,6 +258,11 @@ def apply_compression(
         logger.info(
             f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
         )
+
+        # quantization scales and zp are already initialized but we do not
+        # want to calibrate wrt to these
+        self.model.apply(disable_quantization)
+
         if not self.sequential_update:
             # in non-sequential mode we run one forward batch for all modules
             run_calibration_forward(self.model, dataloader, mask_padding=True)
@@ -271,6 +284,9 @@ def apply_compression(
                 layer_compressor.revert_layer_wrappers()
                 torch.cuda.empty_cache()
 
+        # re-enable quantization
+        self.model.apply(enable_quantization)
+
     def _build_quant_modifier(self):
         """
         Build a quantization modifier based on the specified config_groups,
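
For context, a typical way this code path gets exercised, modeled on the repository's W4A16 example from around this point in history; the model and dataset identifiers are placeholders, and the argument names are assumed from that example and may differ in later releases. GPTQModifier.on_initialize and apply_compression, patched above, run inside this one-shot call.

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# GPTQ quantization of all Linear weights to 4 bits, skipping the LM head
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",   # placeholder model id
    dataset="open_platypus",                      # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
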
8 changes: 8 additions & 0 deletions (second changed file)

@@ -141,6 +141,14 @@ def compress(
         elif hasattr(self.layer, "quantization_scheme"):
             quant_scheme = self.layer.quantization_scheme
             if quant_scheme.weights is not None:
+                # fetch latest correct scale and ZP relevant for any changes
+                # such as activation reordering
+                from compressed_tensors.quantization import (
+                    update_layer_weight_quant_params,
+                )
+
+                update_layer_weight_quant_params(self.layer)
+
                 scale = self.layer.weight_scale
                 zero_point = self.layer.weight_zero_point
                 from compressed_tensors.quantization import QuantizationStrategy
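
The refresh above matters when something between initialization and compression changes how the weight will be quantized, activation reordering being the case named in the comment. A self-contained sketch of why group-wise scales computed earlier go stale once the weight's input columns are permuted; this uses plain torch and is not the compressed-tensors update_layer_weight_quant_params implementation.

import torch

torch.manual_seed(0)
weight = torch.randn(8, 16)  # [out_features, in_features]
group_size = 4               # columns quantized in groups of 4


def groupwise_scales(w: torch.Tensor, group_size: int, num_bits: int = 4) -> torch.Tensor:
    # symmetric scales per (output channel, column group), for illustration only
    qmax = 2 ** (num_bits - 1) - 1
    grouped = w.reshape(w.shape[0], -1, group_size)   # [out, n_groups, group_size]
    return grouped.abs().amax(dim=-1) / qmax          # [out, n_groups]


scales_before = groupwise_scales(weight, group_size)

# activation reordering permutes the weight's input columns, changing which
# columns share a group; scales computed before the permutation are stale
perm = torch.randperm(weight.shape[1])
weight_reordered = weight[:, perm]

scales_after = groupwise_scales(weight_reordered, group_size)
print(torch.allclose(scales_before, scales_after))  # almost certainly False
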
7 changes: 5 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -61,7 +61,9 @@ def on_initialize_structure(self, state: State, **kwargs):
         self._apply_modifier_to_model(module)
         module.apply(freeze_module_quantization)
 
-    def on_initialize(self, state: State, **kwargs) -> bool:
+    def on_initialize(
+        self, state: State, freeze_quantization: bool = True, **kwargs
+    ) -> bool:
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"
@@ -80,7 +82,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self._check_token_distribution(
                 module, threshold=kwargs.get("min_tokens_per_module")
             )
-            module.apply(freeze_module_quantization)
+            if freeze_quantization:
+                module.apply(freeze_module_quantization)
 
         return True


0 comments on commit 476d1eb
