@@ -729,25 +729,38 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2

-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    block_size = (w.shape[0], groupsize)
+    mapping_type = MappingType.SYMMETRIC
+    eps = torch.finfo(torch.float32).eps
+    if TORCH_VERSION_AFTER_2_3:
+        bit_to_dtype = {
+            1: torch.uint1,
+            2: torch.uint2,
+            3: torch.uint3,
+            4: torch.uint4,
+            5: torch.uint5,
+            6: torch.uint6,
+            7: torch.uint7,
+            8: torch.uint8,
+        }
+        assert n_bit in bit_to_dtype, f"unsupported bit: {n_bit}"
+        target_dtype = bit_to_dtype[n_bit]
+        return choose_qparams_affine(w, mapping_type, block_size, target_dtype=target_dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
+    else:
+        ranges = {
+            1: (0, 2**1 - 1),
+            2: (0, 2**2 - 1),
+            3: (0, 2**3 - 1),
+            4: (0, 2**4 - 1),
+            5: (0, 2**5 - 1),
+            6: (0, 2**6 - 1),
+            7: (0, 2**7 - 1),
+            8: (0, 2**8 - 1),
+        }
+        assert n_bit in ranges, f"unsupported bit: {n_bit}"
+        quant_min, quant_max = ranges[n_bit]
+        # using uint8 as the container dtype to simulate the lower-bit range
+        return choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.uint8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)


 if TORCH_VERSION_AFTER_2_3:
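For reference, a minimal usage sketch of the refactored helper. It assumes the function is importable from torchao.quantization.quant_primitives and that choose_qparams_affine clamps each scale to at least eps; the weight shape and bit-width below are purely illustrative and not part of the diff.

import torch
# assumed import path for illustration
from torchao.quantization.quant_primitives import get_group_qparams_symmetric

# illustrative weight: 32 rows, 256 columns, quantized in groups of 128 along the last dim
w = torch.randn(32, 256)

# symmetric 4-bit group quantization parameters, returned in float32
scales, zero_points = get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32)

# scales are expected to be clamped to at least eps inside choose_qparams_affine, so they stay positive
assert torch.all(scales > 0)
print(scales.dtype, zero_points.dtype)  # torch.float32 torch.float32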