Skip to content

Commit d03c14f

Browse files
committed
fix print; update name
1 parent 729028a commit d03c14f

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/compressed_tensors/compressors/pack_quantized.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
5656
def compress(
5757
self,
5858
model_state: Dict[str, Tensor],
59-
model_quant_args: Dict[str, QuantizationArgs],
59+
names_to_scheme: Dict[str, QuantizationArgs],
6060
**kwargs,
6161
) -> Dict[str, Tensor]:
6262
"""
6363
Compresses a dense state dict
6464
6565
:param model_state: state dict of uncompressed model
66-
:param model_quant_args: quantization args for each quantized weight, needed for
66+
:param names_to_scheme: quantization args for each quantized weight, needed for
6767
quantize function to calculate bit depth
6868
:return: compressed state dict
6969
"""
@@ -81,7 +81,7 @@ def compress(
8181
shape = torch.tensor(value.shape)
8282
if scale is not None and zp is not None:
8383
# weight is quantized, compress it
84-
quant_args = model_quant_args[prefix]
84+
quant_args = names_to_scheme[prefix]
8585
if can_quantize(value, quant_args):
8686
# convert weight to an int if not already compressed
8787
value = quantize(

tests/test_compressors/test_pack_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def test_reload_match(tmp_path, num_bits):
116116
"dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
117117
"dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32),
118118
}
119-
print("num bits", num_bits)
119+
120120
names_to_scheme = {
121121
"dummy": QuantizationArgs(num_bits=num_bits),
122122
"dummy2": QuantizationArgs(num_bits=num_bits),

0 commit comments

Comments
 (0)