Skip to content

Commit

Permalink
add utils (#998)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
  • Loading branch information
kylesayrs and dsikka committed Dec 23, 2024
1 parent 93e6020 commit 6830a0f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/llmcompressor/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
"import_from_path",
"getattr_chain",
"DisableKVCache",
"DisableQuantization",
"calibration_forward_context",
]


Expand Down
24 changes: 24 additions & 0 deletions tests/llmcompressor/utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from types import SimpleNamespace

import pytest
import torch

from llmcompressor.utils import (
ALL_TOKEN,
DisableQuantization,
calibration_forward_context,
convert_to_bool,
flatten_iterable,
getattr_chain,
Expand Down Expand Up @@ -124,3 +127,24 @@ def test_getattr_chain():
assert getattr_chain(base, "b.d.dne", "default") == "default"
with pytest.raises(AttributeError):
getattr_chain(base, "b.d.dne")


def test_DisableQuantization():
    """DisableQuantization turns quantization off inside the context and restores it on exit."""
    # a bare linear layer stands in for a quantized model
    module = torch.nn.Linear(1, 1)

    with DisableQuantization(module):
        # quantization must be off while inside the context
        assert not module.quantization_enabled

    # ...and re-enabled once the context exits
    assert module.quantization_enabled


def test_calibration_forward_context():
    """calibration_forward_context disables grad, quantization, and the KV cache
    inside the context, and restores all three on exit."""
    module = torch.nn.Linear(1, 1)
    # minimal HF-style config carrying only the attribute the context toggles
    module.config = SimpleNamespace(use_cache=True)

    with calibration_forward_context(module):
        # all three calibration-relevant switches are off inside the context
        assert not torch.is_grad_enabled()
        assert not module.quantization_enabled
        assert not module.config.use_cache

    # every switch is restored to its prior state after exit
    assert torch.is_grad_enabled()
    assert module.quantization_enabled
    assert module.config.use_cache

0 comments on commit 6830a0f

Please sign in to comment.