Commit 729028a

update to test 8 bits; update kwargs
1 parent 87bd392 commit 729028a

Showing 3 changed files with 13 additions and 17 deletions.

src/compressed_tensors/compressors/base.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor
         raise NotImplementedError()
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
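Adding **kwargs to the base signature means every Compressor subclass can now be called with the same keyword arguments, with each implementation consuming only the keywords it understands. A minimal sketch of that pattern, with illustrative subclass names rather than the library's real hierarchy:

from typing import Generator, Tuple

import torch
from torch import Tensor


class Compressor:
    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
        raise NotImplementedError()


class PackedCompressor(Compressor):
    # A packed-quantization compressor needs the per-layer schemes to know
    # how many bits to unpack, so it names the keyword explicitly.
    def decompress(self, path_to_model_or_tensors, device="cpu",
                   names_to_scheme=None, **kwargs):
        yield "layer.weight", torch.zeros(1, device=device)


class SparseCompressor(Compressor):
    # A sparsity compressor has no use for names_to_scheme; the argument
    # lands in **kwargs and is silently ignored.
    def decompress(self, path_to_model_or_tensors, device="cpu", **kwargs):
        yield "layer.weight", torch.zeros(1, device=device)


# The caller no longer needs to know which subclass it holds:
for compressor in (PackedCompressor(), SparseCompressor()):
    for name, tensor in compressor.decompress("model.safetensors",
                                              names_to_scheme={}):
        print(name, tensor.shape)

This uniform signature is what lets the call site in model_compressor.py below drop its format check.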

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 4 additions & 10 deletions
@@ -25,7 +25,7 @@
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors import Compressor
-from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationStatus,
@@ -254,15 +254,9 @@ def decompress(self, model_path: str, model: Module):
         if self.quantization_compressor is not None:
             names_to_scheme = apply_quantization_config(model, self.quantization_config)
             load_pretrained_quantization(model, model_path)
-            if (
-                self.quantization_config.format
-                == CompressionFormat.pack_quantized.value
-            ):
-                dense_gen = self.quantization_compressor.decompress(
-                    model_path, names_to_scheme=names_to_scheme
-                )
-            else:
-                dense_gen = self.quantization_compressor.decompress(model_path)
+            dense_gen = self.quantization_compressor.decompress(
+                model_path, names_to_scheme=names_to_scheme
+            )
             self._replace_weights(dense_gen, model)
 
         def update_status(module):
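Reduced to its essentials, the call-site change is the following before/after sketch (the free functions and the literal format string are illustrative stand-ins for the method bodies in the diff):

# Before: branch on the compression format to decide which keywords are
# safe to pass, since only the packed compressor accepted names_to_scheme.
def decompress_old(compressor, config, model_path, names_to_scheme):
    if config.format == "pack-quantized":  # illustrative format value
        return compressor.decompress(model_path, names_to_scheme=names_to_scheme)
    return compressor.decompress(model_path)


# After: pass the keyword unconditionally; compressors that do not need it
# absorb it through the base signature's **kwargs.
def decompress_new(compressor, model_path, names_to_scheme):
    return compressor.decompress(model_path, names_to_scheme=names_to_scheme)

With the branch gone, CompressionFormat is no longer referenced here, which is why its import is removed in the first hunk.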

tests/test_compressors/test_pack_quant.py

Lines changed: 8 additions & 6 deletions
@@ -32,10 +32,10 @@
 from safetensors.torch import save_file
 
 
-def get_dummy_quant_config():
+def get_dummy_quant_config(num_bits=4):
     config_groups = {
         "group_1": QuantizationScheme(
-            targets=["Linear"], weights=QuantizationArgs(num_bits=4)
+            targets=["Linear"], weights=QuantizationArgs(num_bits=num_bits)
         ),
     }
     ignore = ["lm_head"]
@@ -106,7 +106,8 @@ def test_repack(value):
     assert torch.equal(value, unpacked)
 
 
-def test_reload_match(tmp_path):
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_reload_match(tmp_path, num_bits):
     dense_state_dict = {
         "dummy.weight": torch.rand((511, 350)),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
@@ -115,11 +116,12 @@ def test_reload_match(tmp_path):
         "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
         "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32),
     }
+    print("num bits", num_bits)
     names_to_scheme = {
-        "dummy": QuantizationArgs(num_bits=4),
-        "dummy2": QuantizationArgs(num_bits=4),
+        "dummy": QuantizationArgs(num_bits=num_bits),
+        "dummy2": QuantizationArgs(num_bits=num_bits),
     }
-    quant_config = get_dummy_quant_config()
+    quant_config = get_dummy_quant_config(num_bits)
 
     compressor = PackedQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {
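Parametrizing over num_bits exercises both packing layouts: at 4 bits two values share each byte of storage, while at 8 bits each value occupies a full byte, so a round trip only matches when compression and decompression agree on the width. A self-contained sketch of that effect, independent of the library's actual packing layout:

import torch


def pack(values: torch.Tensor, num_bits: int) -> torch.Tensor:
    assert num_bits in (4, 8)
    if num_bits == 8:
        # One value per byte; nothing to pack.
        return values.to(torch.uint8)
    # Two 4-bit values per byte, low nibble first; pad odd lengths.
    if values.numel() % 2:
        values = torch.cat([values, values.new_zeros(1)])
    v = values.to(torch.uint8)
    return v[0::2] | (v[1::2] << 4)


def unpack(packed: torch.Tensor, num_bits: int, numel: int) -> torch.Tensor:
    if num_bits == 8:
        return packed[:numel]
    # Split each byte back into its two nibbles and drop any padding.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=1).flatten()[:numel]


# 4-bit round trip: 511 values fit in 256 bytes.
vals4 = torch.randint(0, 16, (511,))
assert torch.equal(unpack(pack(vals4, 4), 4, 511), vals4.to(torch.uint8))

# 8-bit round trip: 511 values need all 511 bytes.
vals8 = torch.randint(0, 256, (511,))
assert torch.equal(unpack(pack(vals8, 8), 8, 511), vals8.to(torch.uint8))

Decompressing 4-bit data as if it were 8-bit (or vice versa) yields tensors of the wrong shape and values, which is exactly the failure mode the parametrized reload test guards against.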
