2 files changed: +4 −4

src/compressed_tensors/compressors

@@ -56,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict

         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -81,7 +81,7 @@ def compress(
         shape = torch.tensor(value.shape)
         if scale is not None and zp is not None:
             # weight is quantized, compress it
-            quant_args = model_quant_args[prefix]
+            quant_args = names_to_scheme[prefix]
             if can_quantize(value, quant_args):
                 # convert weight to an int if not already compressed
                 value = quantize(
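For context, a minimal sketch of a call site under the renamed argument. Only the `compress` signature and the prefix-keyed lookup come from this diff; the import paths, instantiation, layer name, and tensor values below are assumptions for illustration:

```python
from typing import Dict

import torch
from torch import Tensor

# Import paths are assumptions, not taken from this PR.
from compressed_tensors.compressors import PackedQuantizationCompressor
from compressed_tensors.quantization import QuantizationArgs

# A dense state dict with one quantizable weight plus its scale and zero
# point, mirroring the layout used in the test file below.
model_state: Dict[str, Tensor] = {
    "dummy.weight": torch.rand(8, 8) * 2 - 1,
    "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
    "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32),
}

# Keys are weight prefixes ("dummy"), matching the lookup
# quant_args = names_to_scheme[prefix] in the hunk above.
names_to_scheme: Dict[str, QuantizationArgs] = {
    "dummy": QuantizationArgs(num_bits=4),
}

# Instantiation details may differ; shown without config for brevity.
compressor = PackedQuantizationCompressor()
compressed = compressor.compress(model_state, names_to_scheme=names_to_scheme)
```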
And in the test file:

@@ -116,7 +116,7 @@ def test_reload_match(tmp_path, num_bits):
116116 "dummy2.weight_scale" : torch .tensor (0.02 , dtype = torch .float32 ),
117117 "dummy2.weight_zero_point" : torch .tensor (15 , dtype = torch .int32 ),
118118 }
119- print ( "num bits" , num_bits )
119+
120120 names_to_scheme = {
121121 "dummy" : QuantizationArgs (num_bits = num_bits ),
122122 "dummy2" : QuantizationArgs (num_bits = num_bits ),