diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index bdf27f620..e6bf7b319 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -4,6 +4,7 @@
 """
 
 import ast
+import contextlib
 import errno
 import fnmatch
 import glob
@@ -23,6 +24,7 @@
 
 import numpy
 import torch
+from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 
 __all__ = [
@@ -61,6 +63,8 @@
     "import_from_path",
     "getattr_chain",
     "DisableKVCache",
+    "DisableQuantization",
+    "calibration_forward_context",
 ]
 
 
@@ -1080,3 +1084,32 @@ def __enter__(self):
 
     def __exit__(self, _exc_type, _exc_val, _exc_tb):
         self.config.use_cache = self.restore_value
+
+
+@contextlib.contextmanager
+def DisableQuantization(model: torch.nn.Module):
+    """
+    Disable quantization from QuantizationModifier
+    """
+    model.apply(disable_quantization)
+    yield
+    model.apply(enable_quantization)
+
+
+@contextlib.contextmanager
+def calibration_forward_context(model: torch.nn.Module):
+    """
+    Context in which all calibration forward passes should occur.
+
+    - Remove gradient calculations
+    - Disable the KV cache
+    - Disable quantization from QuantizationModifier
+    """
+    model.eval()
+
+    with (
+        torch.no_grad(),
+        DisableKVCache(model),
+        DisableQuantization(model),
+    ):
+        yield
diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py
index 0b45a4670..0f466555f 100644
--- a/src/llmcompressor/utils/metric_logging.py
+++ b/src/llmcompressor/utils/metric_logging.py
@@ -1,9 +1,11 @@
+import time
 from typing import List, Tuple
 
+import torch
 from loguru import logger
 from torch.nn import Module
 
-__all__ = ["get_GPU_memory_usage", "get_layer_size_mb"]
+__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"]
 
 
 def get_GPU_memory_usage() -> List[Tuple]:
@@ -51,3 +53,50 @@ def get_layer_size_mb(module: Module) -> float:
     total_size_mb = total_size / (1e6)  # Convert bytes to MB
 
     return total_size_mb
+
+
+class CompressionLogger:
+    """
+    Log metrics related to the compression algorithm
+
+    :param start_tick: time when the algorithm started
+    :param loss: loss resulting from the algorithm
+    """
+
+    def __init__(self, module: torch.nn.Module):
+        self.module = module
+        self.start_tick = None
+        self.loss = None
+
+    def set_loss(self, loss: float):
+        self.loss = loss
+
+    def __enter__(self) -> "CompressionLogger":
+        self.start_tick = time.time()
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        stop_tick = time.time()
+        patch = logger.patch(lambda r: r.update(function="compress"))
+
+        if self.start_tick is not None:
+            duration = stop_tick - self.start_tick
+            patch.log("METRIC", f"time {duration:.2f}")
+        if self.loss is not None:
+            patch.log("METRIC", f"error {self.loss:.2f}")
+
+        gpu_usage = get_GPU_memory_usage()
+        if len(gpu_usage) > 0:
+            for i in range(len(gpu_usage)):
+                perc = gpu_usage[i][0] * 100
+                total_memory = int(gpu_usage[i][1])  # GB
+                patch.log(
+                    "METRIC",
+                    (
+                        f"GPU {i} | usage: {perc:.2f}%"
+                        f" | total memory: {total_memory} GB"
+                    ),
+                )
+
+        compressed_size = get_layer_size_mb(self.module)
+        patch.log("METRIC", f"Compressed module size: {compressed_size} MB")
diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py
index dc9e44954..c1975f9d5 100644
--- a/tests/llmcompressor/utils/test_helpers.py
+++ b/tests/llmcompressor/utils/test_helpers.py
@@ -1,9 +1,12 @@
 from types import SimpleNamespace
 
 import pytest
+import torch
 
 from llmcompressor.utils import (
     ALL_TOKEN,
+    DisableQuantization,
+    calibration_forward_context,
     convert_to_bool,
     flatten_iterable,
     getattr_chain,
@@ -124,3 +127,24 @@ def test_getattr_chain():
     assert getattr_chain(base, "b.d.dne", "default") == "default"
     with pytest.raises(AttributeError):
         getattr_chain(base, "b.d.dne")
+
+
+def test_DisableQuantization():
+    model = torch.nn.Linear(1, 1)
+    with DisableQuantization(model):
+        assert not model.quantization_enabled
+    assert model.quantization_enabled
+
+
+def test_calibration_forward_context():
+    model = torch.nn.Linear(1, 1)
+    model.config = SimpleNamespace()
+    model.config.use_cache = True
+
+    with calibration_forward_context(model):
+        assert not torch.is_grad_enabled()
+        assert not model.quantization_enabled
+        assert not model.config.use_cache
+    assert torch.is_grad_enabled()
+    assert model.quantization_enabled
+    assert model.config.use_cache
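Reviewer note: a minimal sketch of how the two new utilities are intended to compose during calibration and compression. The `calibrate_and_compress` function, `calibration_loader`, and `compress_module` below are hypothetical placeholders for illustration only; they are not introduced by this diff.

# Hypothetical usage sketch (not part of this diff). `calibration_loader` and
# `compress_module` stand in for a calibration dataloader and a per-module
# compression step.
import torch

from llmcompressor.utils import calibration_forward_context
from llmcompressor.utils.metric_logging import CompressionLogger


def calibrate_and_compress(model, calibration_loader, compress_module):
    # Calibration forward passes run with gradients, the KV cache, and
    # QuantizationModifier quantization all disabled.
    with calibration_forward_context(model):
        for batch in calibration_loader:
            model(**batch)

    # Each compressed module gets duration, loss, GPU usage, and compressed
    # size reported through the METRIC log level.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            with CompressionLogger(module) as comp_logger:
                loss = compress_module(module)
                comp_logger.set_loss(loss)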