diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 031e5ef14..3a12d9b63 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -200,6 +200,104 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_qat_8da4w_quantizer_disable_fake_quant(self):
+        """
+        Test that 8da4w QAT with fake quant disabled matches nn.Linear in forward.
+        """
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+            disable_8da4w_fake_quant,
+            enable_8da4w_fake_quant,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        m2 = copy.deepcopy(m)
+        m3 = copy.deepcopy(m)
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model = quantizer.prepare(m)
+        qat_model.apply(disable_8da4w_fake_quant)
+        self.assertFalse(qat_model.linear1._fake_quant_enabled)
+        self.assertFalse(qat_model.linear2._fake_quant_enabled)
+        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+
+        # With fake quant disabled, the QAT linears behave like plain nn.Linear
+        m2.linear1.weight = qat_model.linear1.weight
+        m2.linear2.weight = qat_model.linear2.weight
+        m2.sub.linear.weight = qat_model.sub.linear.weight
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        nn_out = m2(*x2)
+        torch.testing.assert_close(nn_out, qat_out, atol=0, rtol=0)
+
+        # Re-enable fake quant
+        qat_model.apply(enable_8da4w_fake_quant)
+        self.assertTrue(qat_model.linear1._fake_quant_enabled)
+        self.assertTrue(qat_model.linear2._fake_quant_enabled)
+        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+
+        # Fake quant should now be applied as usual
+        quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model2 = quantizer2.prepare(m3)
+        qat_model2.linear1.weight = qat_model.linear1.weight
+        qat_model2.linear2.weight = qat_model.linear2.weight
+        qat_model2.sub.linear.weight = qat_model.sub.linear.weight
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        qat_out2 = qat_model2(*x2)
+        torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
+        """
+        Test that 8da4w QAT with fake quant disabled matches nn.Linear in backward.
+        """
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+            disable_8da4w_fake_quant,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        nn_model = copy.deepcopy(m)
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model = quantizer.prepare(m)
+        qat_model.apply(disable_8da4w_fake_quant)
+        nn_model.linear1.weight = qat_model.linear1.weight
+        nn_model.linear2.weight = qat_model.linear2.weight
+        nn_model.sub.linear.weight = qat_model.sub.linear.weight
+
+        # Simulate training for both models
+        optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+        loss_fn1 = torch.nn.CrossEntropyLoss()
+        loss_fn2 = torch.nn.CrossEntropyLoss()
+        example_inputs = nn_model.example_inputs()
+        target = torch.randn(1, 64).float()
+        output1 = nn_model(*example_inputs)
+        output2 = qat_model(*example_inputs)
+        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
+        loss1 = loss_fn1(output1, target)
+        loss2 = loss_fn2(output2, target)
+        optimizer1.zero_grad()
+        optimizer2.zero_grad()
+        loss1.backward()
+        loss2.backward()
+        optimizer1.step()
+        optimizer2.step()
+
+        # After one training step, weights should match exactly
+        torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
+        torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index 7ba64f3ac..5bf01d55d 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
@@ -129,30 +129,43 @@ def __init__(
         self.groupsize = groupsize
         self.precision = precision
         self.scales_precision = scales_precision
+        self._fake_quant_enabled = True
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
-        (act_qmin, act_qmax) = self._get_qmin_qmax(8)
-        (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
-            x, torch.int8,  # dtype not used
-        )
-        x_fq = fake_quantize_per_token(
-            x, act_scales, act_zp, act_qmin, act_qmax,
-        )
+        if self._fake_quant_enabled:
+            (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                x, torch.int8,  # dtype not used
+            )
+            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            x_fq = fake_quantize_per_token(
+                x, act_scales, act_zp, act_qmin, act_qmax,
+            )
+        else:
+            x_fq = x
 
         # weights: int4 grouped per channel symmetric quant
-        (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
-        (weight_scales, weight_zp) = get_group_qparams_symmetric(
-            self.weight, 4, self.groupsize, self.scales_precision,
-        )
-        w_fq = fake_quantize_per_channel_group(
-            self.weight,
-            weight_scales,
-            weight_zp,
-            weight_qmin,
-            weight_qmax,
-            self.groupsize,
-        )
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, 4, self.groupsize, self.scales_precision,
+            )
+            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            w_fq = fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
         return torch.nn.functional.linear(x_fq, w_fq)
 
     # TODO: move this to common util
@@ -161,6 +174,20 @@ def _get_qmin_qmax(self, n_bit: int):
         qmax = 2 ** (n_bit - 1) - 1
         return (qmin, qmax)
 
+def enable_8da4w_fake_quant(mod: torch.nn.Module):
+    """
+    Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    """
+    if isinstance(mod, Int8DynActInt4WeightQATLinear):
+        mod.enable_fake_quant()
+
+def disable_8da4w_fake_quant(mod: torch.nn.Module):
+    """
+    Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
+    """
+    if isinstance(mod, Int8DynActInt4WeightQATLinear):
+        mod.disable_fake_quant()
+
 # ========================
 # |   QUANT PRIMITIVES   |
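
For context when reviewing, a minimal usage sketch of the toggles this patch adds, assuming a prepared QAT model. The toy Sequential model, the groupsize choice, and the warmup-then-enable schedule are illustrative assumptions, not part of the patch; only Int8DynActInt4WeightQATQuantizer, prepare(), and the two apply-able toggle functions come from the code above.

import torch
from torchao.quantization.prototype.qat import (
    Int8DynActInt4WeightQATQuantizer,
    disable_8da4w_fake_quant,
    enable_8da4w_fake_quant,
)

# Hypothetical model: any module tree containing nn.Linear layers whose
# in_features divide evenly by the group size.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

# Swap nn.Linear modules for Int8DynActInt4WeightQATLinear.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = quantizer.prepare(model)

# Warm up in full precision: with fake quant disabled, forward is a plain
# F.linear and matches nn.Linear exactly (what the new tests verify).
model.apply(disable_8da4w_fake_quant)
# ... run warmup fine-tuning steps here ...

# Re-enable int8-activation / int4-weight fake quantization for QAT proper.
model.apply(enable_8da4w_fake_quant)
# ... continue training ...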