Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calibration and Compression Contexts #998

Merged
merged 2 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/llmcompressor/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import ast
import contextlib
import errno
import fnmatch
import glob
Expand All @@ -23,6 +24,7 @@

import numpy
import torch
from compressed_tensors.quantization import disable_quantization, enable_quantization
from loguru import logger

__all__ = [
Expand Down Expand Up @@ -61,6 +63,8 @@
"import_from_path",
"getattr_chain",
"DisableKVCache",
"DisableQuantization",
"calibration_forward_context",
]


Expand Down Expand Up @@ -1080,3 +1084,32 @@ def __enter__(self):

def __exit__(self, _exc_type, _exc_val, _exc_tb):
self.config.use_cache = self.restore_value


@contextlib.contextmanager
def DisableQuantization(model: torch.nn.Module):
"""
Disable quantization from QuantizationModifier
"""
model.apply(disable_quantization)
yield
model.apply(enable_quantization)


@contextlib.contextmanager
def calibration_forward_context(model: torch.nn.Module):
"""
Context in which all calibration forward passes should occur.

- Remove gradient calculations
- Disable the KV cache
- Disable quantization from QuantizationModifier
"""
model.eval()

with (
kylesayrs marked this conversation as resolved.
Show resolved Hide resolved
torch.no_grad(),
DisableKVCache(model),
dsikka marked this conversation as resolved.
Show resolved Hide resolved
DisableQuantization(model),
):
yield
51 changes: 50 additions & 1 deletion src/llmcompressor/utils/metric_logging.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import time
from typing import List, Tuple

import torch
from loguru import logger
from torch.nn import Module

__all__ = ["get_GPU_memory_usage", "get_layer_size_mb"]
__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"]


def get_GPU_memory_usage() -> List[Tuple]:
Expand Down Expand Up @@ -51,3 +53,50 @@ def get_layer_size_mb(module: Module) -> float:
total_size_mb = total_size / (1e6) # Convert bytes to MB

return total_size_mb


class CompressionLogger:
dsikka marked this conversation as resolved.
Show resolved Hide resolved
"""
Log metrics related to compression algorithm

:param start_tick: time when algorithm started"
:param losses: loss as result of algorithm
"""

def __init__(self, module: torch.nn.Module):
self.module = module
self.start_tick = None
self.loss = None

def set_loss(self, loss: float):
self.loss = loss

def __enter__(self) -> "CompressionLogger":
self.start_tick = time.time()
return self

def __exit__(self, _exc_type, _exc_val, _exc_tb):
stop_tick = time.time()
patch = logger.patch(lambda r: r.update(function="compress"))

if self.start_tick is not None:
duration = stop_tick - self.start_tick
patch.log("METRIC", f"time {duration:.2f}")
if self.loss is not None:
patch.log("METRIC", f"error {self.loss:.2f}")

gpu_usage = get_GPU_memory_usage()
if len(gpu_usage) > 0:
for i in range(len(gpu_usage)):
perc = gpu_usage[i][0] * 100
total_memory = int(gpu_usage[i][1]) # GB
patch.log(
"METRIC",
(
f"GPU {i} | usage: {perc:.2f}%"
f" | total memory: {total_memory} GB"
),
)

compressed_size = get_layer_size_mb(self.module)
patch.log("METRIC", f"Compressed module size: {compressed_size} MB")
24 changes: 24 additions & 0 deletions tests/llmcompressor/utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from types import SimpleNamespace

import pytest
import torch

from llmcompressor.utils import (
ALL_TOKEN,
DisableQuantization,
calibration_forward_context,
convert_to_bool,
flatten_iterable,
getattr_chain,
Expand Down Expand Up @@ -124,3 +127,24 @@ def test_getattr_chain():
assert getattr_chain(base, "b.d.dne", "default") == "default"
with pytest.raises(AttributeError):
getattr_chain(base, "b.d.dne")


def test_DisableQuantization():
model = torch.nn.Linear(1, 1)
with DisableQuantization(model):
assert not model.quantization_enabled
assert model.quantization_enabled


def test_calibration_forward_context():
model = torch.nn.Linear(1, 1)
model.config = SimpleNamespace()
model.config.use_cache = True

with calibration_forward_context(model):
assert not torch.is_grad_enabled()
assert not model.quantization_enabled
assert not model.config.use_cache
assert torch.is_grad_enabled()
assert model.quantization_enabled
assert model.config.use_cache
Loading