@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
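+        # Gather the scale/zero point tensors of every weight fake quantizer in the model.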
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
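+        # Either flag on its own is accepted; only the combination should raise.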
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1591,6 +1619,95 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
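+        # Attach weight fake quantizers to the linear layers; their qparams start out as None.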
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
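+        # Dummy target; the goal is only to drive a backward pass and optimizer step.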
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()