From 476d1eb941fddaf9b9c2c130a26d553a42263916 Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Mon, 22 Jul 2024 09:31:37 -0400
Subject: [PATCH] GPTQ - move calibration of quantization params to after
 hessian calibration (#25)

---
 .../modifiers/quantization/gptq/base.py      | 20 +++++++++++++++++--
 .../quantization/gptq/utils/gptq_wrapper.py  |  8 ++++++++
 .../quantization/quantization/base.py        |  7 +++++--
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 2ccfb114a..d9d7959db 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,7 +1,12 @@
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import QuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    disable_quantization,
+    enable_quantization,
+    freeze_module_quantization,
+)
 from loguru import logger
 from pydantic import Field
 from torch.nn import Module
@@ -163,7 +168,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         if not self.initialized_structure_:
             self.on_initialize_structure(state, **kwargs)
         if self.quantization_modifier_:
-            self.quantization_modifier_.initialize(state, **kwargs)
+            self.quantization_modifier_.initialize(
+                state, freeze_quantization=False, **kwargs
+            )
         if not self.quantize:
             raise ValueError("To use the GPTQModifier, quantization must be enabled.")
 
@@ -178,6 +185,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         self.initialize_compression(modifiable_model, calibration_dataloader)
         self.apply_compression(calibration_dataloader)
+        state.model.apply(freeze_module_quantization)
 
         return True
 
@@ -250,6 +258,11 @@ def apply_compression(
         logger.info(
             f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
         )
+
+        # quantization scales and zp are already initialized but we do not
+        # want to calibrate wrt to these
+        self.model.apply(disable_quantization)
+
         if not self.sequential_update:
             # in non-sequential mode we run one forward batch for all modules
             run_calibration_forward(self.model, dataloader, mask_padding=True)
@@ -271,6 +284,9 @@ def apply_compression(
                 layer_compressor.revert_layer_wrappers()
                 torch.cuda.empty_cache()
 
+        # re-enable quantization
+        self.model.apply(enable_quantization)
+
     def _build_quant_modifier(self):
         """
         Build a quantization modifier based on the specified config_groups,
diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
index db2afc64a..5bc3f14f3 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -141,6 +141,14 @@ def compress(
         elif hasattr(self.layer, "quantization_scheme"):
             quant_scheme = self.layer.quantization_scheme
             if quant_scheme.weights is not None:
+                # fetch latest correct scale and ZP relevant for any changes
+                # such as activation reordering
+                from compressed_tensors.quantization import (
+                    update_layer_weight_quant_params,
+                )
+
+                update_layer_weight_quant_params(self.layer)
+
                 scale = self.layer.weight_scale
                 zero_point = self.layer.weight_zero_point
                 from compressed_tensors.quantization import QuantizationStrategy
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 434f6f2d8..b90ec250f 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -61,7 +61,9 @@ def on_initialize_structure(self, state: State, **kwargs):
             self._apply_modifier_to_model(module)
             module.apply(freeze_module_quantization)
 
-    def on_initialize(self, state: State, **kwargs) -> bool:
+    def on_initialize(
+        self, state: State, freeze_quantization: bool = True, **kwargs
+    ) -> bool:
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"
@@ -80,7 +82,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self._check_token_distribution(
                 module, threshold=kwargs.get("min_tokens_per_module")
             )
-            module.apply(freeze_module_quantization)
+            if freeze_quantization:
+                module.apply(freeze_module_quantization)
 
         return True
 
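
For reference, the ordering this patch enforces can be summarized with the short Python sketch below. It is an illustration only, not code from the repository: it reuses the helpers named in the diff (disable_quantization, enable_quantization, freeze_module_quantization, run_calibration_forward, revert_layer_wrappers, torch.cuda.empty_cache), while gptq_oneshot_sketch and layer_compressor.compress() are assumed stand-ins for the surrounding flow and the per-layer GPTQ step.

    import torch
    from compressed_tensors.quantization import (
        disable_quantization,
        enable_quantization,
        freeze_module_quantization,
    )

    def gptq_oneshot_sketch(model, dataloader, layer_compressors, run_calibration_forward):
        # Scales/zero-points are already initialized, but the calibration
        # forward passes should not run through fake quantization.
        model.apply(disable_quantization)

        # Hessian calibration: collect statistics with unquantized forward passes.
        run_calibration_forward(model, dataloader, mask_padding=True)

        # Per-layer GPTQ; each wrapper refreshes its weight scale/zero-point
        # (e.g. after activation reordering) right before quantizing.
        for layer_compressor in layer_compressors:
            layer_compressor.compress()  # assumed per-layer entry point
            layer_compressor.revert_layer_wrappers()
            torch.cuda.empty_cache()

        # Re-enable quantization, then freeze the now-final quant params.
        model.apply(enable_quantization)
        model.apply(freeze_module_quantization)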