diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 892d70bad..60f38c792 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -119,7 +119,7 @@ def _set_ptq_weight(
         Set the weight to the quantized version of the given fp32 weights,
         for making linear outputs comparable with QAT.
         """
-        n_bit = 2
+        n_bit = 3
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
         (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
         q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index d91200e37..c26ca3470 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -1017,7 +1017,7 @@ def linear_forward_8da4w(

     # TODO: better API
     # weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
-    n_bit = 2
+    n_bit = 3
     quant_min = -(2 ** (n_bit - 1))
     quant_max = 2 ** (n_bit - 1) - 1
     w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
@@ -1196,7 +1196,7 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
         self.scales_precision: torch.dtype = scales_precision
-        print("PTQ is running 2-bit weight only quant!")
+        print("PTQ is running 3-bit weight only quant!")

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -1237,7 +1237,7 @@ def _create_quantized_state_dict(
                     zeros,
                 ) = group_quantize_tensor_symmetric(
                     weight.to(self.precision),
-                    2,  # n_bit
+                    3,  # n_bit
                     self.groupsize,
                     self.scales_precision,
                 )
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index a7a17d5ef..f8db80030 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -109,7 +109,7 @@ def _convert_qat_linear_8da4w(
             setattr(module, name, quantized_linear)

             # Load weights and qparams into quantized linear
-            n_bit = 2
+            n_bit = 3
             (qmin, qmax) = child._get_qmin_qmax(n_bit)
             (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
             q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
@@ -159,7 +159,7 @@ def __init__(
         # TODO: make this configurable?
         self.zero_points_precision = torch.int32
         self._fake_quant_enabled = True
-        print("QAT: I'm runining 2-bit weight only quant!")
+        print("QAT: I'm running 3-bit weight only quant!")

     def enable_fake_quant(self, enabled: bool = True):
         self._fake_quant_enabled = enabled
@@ -181,14 +181,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x_fq = x
         x_fq = x

-        # weights: int2 grouped per channel symmetric quant
+        # weights: int3 grouped per channel symmetric quant
        if self._fake_quant_enabled:
             (weight_scales, weight_zp) = get_group_qparams_symmetric(
-                self.weight, 2, self.groupsize, self.scales_precision,
+                self.weight, 3, self.groupsize, self.scales_precision,
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(2)
+            (weight_qmin, weight_qmax) = self._get_qmin_qmax(3)
             w_fq = fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
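
For reference (not part of the patch itself): every hunk above derives its integer range from n_bit via the signed symmetric formula already visible in linear_forward_8da4w, so bumping n_bit from 2 to 3 widens the quantization range from [-2, 1] to [-4, 3]. A minimal standalone Python sketch of that arithmetic (the helper name qmin_qmax is illustrative, not from the codebase):

    # Signed symmetric n-bit range: qmin = -(2 ** (n - 1)), qmax = 2 ** (n - 1) - 1,
    # matching the quant_min/quant_max computed in linear_forward_8da4w above.
    def qmin_qmax(n_bit: int) -> tuple:
        return -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1

    assert qmin_qmax(2) == (-2, 1)  # range before this change
    assert qmin_qmax(3) == (-4, 3)  # range after this change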