@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1513,6 +1541,95 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()
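
As a quick reference for reviewers, below is a minimal sketch of the range-learning QAT flow these new tests exercise: build a weight FakeQuantizeConfig with range_learning=True, apply IntXQuantizationAwareTrainingConfig via quantize_, call initialize_fake_quantizers once before constructing the optimizer so the learnable scales and zero points are registered as parameters, then train as usual. The toy model, loss, and hyperparameters here are assumptions for illustration and are not part of this patch; quantize_ is assumed to be imported from torchao.quantization, while the QAT names mirror the imports added above.

# Illustrative sketch only; the toy model, loss, and hyperparameters are
# assumptions, not part of the patch. The torchao QAT APIs match the ones
# imported in the diff above.
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

# int8 per-channel weight fake quantization with learnable scale/zero point
weight_config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)

# Toy model standing in for the test's `M`; any tree of nn.Linear modules
# behaves the same way
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 512),
)
example_inputs = (torch.randn(1, 512),)

# Swap nn.Linear for fake-quantized linears, then initialize the learnable
# qparams from one forward pass *before* building the optimizer, so the new
# scale/zero-point Parameters are picked up by model.parameters()
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=weight_config))
initialize_fake_quantizers(model, example_inputs)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = torch.nn.MSELoss()

# One training step: scales and zero points now receive gradients like weights
out = model(*example_inputs)
loss = loss_fn(out, torch.randn_like(out))
optimizer.zero_grad()
loss.backward()
optimizer.step()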