From 045b99e1463d44189980bccef74ad0967c390ca7 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 9 Apr 2025 16:09:58 -0700 Subject: [PATCH] Enable range learning for QAT **Summary:** This commit adds the option for QAT users to use range learning during training. Range learning means we train the scale and zero point instead of recomputing them based on the input at every iteration. Example usage: ``` import torch from torchao.quantization import quantize_ from torchao.quantization.qat import ( FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, initialize_fake_quantizers, ) config = FakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, range_learning=True, scale_precision=torch.float32, zero_point_precision=torch.float32, ) m = M() example_inputs = (torch.randn(16, 32),) quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config)) # New required step to turn scales and zero points into trainable # `nn.Parameters`, must be called before initializing the optimizer initialize_fake_quantizers(m, example_inputs) # initialize the optimizer # do training ``` **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_dynamic_and_range_learning python test/quantization/test_qat.py -k test_fake_quantizer_range_learning python test/quantization/test_qat.py -k test_qat_range_learning --- test/quantization/test_qat.py | 117 +++++++++++++++++++++ third_party/cutlass | 2 +- torchao/quantization/qat/__init__.py | 8 +- torchao/quantization/qat/api.py | 29 ++++- torchao/quantization/qat/embedding.py | 2 + torchao/quantization/qat/fake_quantizer.py | 48 +++++++-- torchao/quantization/qat/linear.py | 16 +-- torchao/quantization/qat/utils.py | 14 +++ 8 files changed, 217 insertions(+), 19 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index d655abaf62..7444c3dbb5 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -9,6 +9,7 @@ import copy import unittest +from typing import List import torch import torch.nn.functional as F @@ -26,7 +27,9 @@ from torchao.quantization.qat.api import ( ComposableQATQuantizer, FakeQuantizeConfig, + IntXQuantizationAwareTrainingConfig, from_intx_quantization_aware_training, + initialize_fake_quantizers, intx_quantization_aware_training, ) from torchao.quantization.qat.embedding import ( @@ -99,6 +102,16 @@ def __init__(self): def example_inputs(self): return (torch.randn(1, 512).to(torch.float),) + def _get_all_weight_qparams(self) -> List[torch.Tensor]: + return [ + self.linear1.weight_fake_quantizer.scale, + self.linear1.weight_fake_quantizer.zero_point, + self.sub.linear.weight_fake_quantizer.scale, + self.sub.linear.weight_fake_quantizer.zero_point, + self.linear2.weight_fake_quantizer.scale, + self.linear2.weight_fake_quantizer.zero_point, + ] + def forward(self, x): x = self.linear1(x) x = self.sub(x) @@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self): FakeQuantizeConfig(TorchAODType.INT7, "per_token") FakeQuantizeConfig(torch.int8, "per_token") + def test_fake_quantize_config_dynamic_and_range_learning(self): + """ + Test that `is_dynamic` and `range_learning` cannot both be set. + """ + FakeQuantizeConfig( + torch.int8, "per_channel", is_dynamic=True, range_learning=False + ) + FakeQuantizeConfig( + torch.int8, "per_channel", is_dynamic=False, range_learning=True + ) + with self.assertRaisesRegex(ValueError, "not compatible"): + FakeQuantizeConfig( + torch.int8, "per_channel", is_dynamic=True, range_learning=True + ) + @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) @@ -1591,6 +1619,95 @@ def test_qat_8da4w_eps(self): actual_out = converted_model.linear1(x) torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_fake_quantizer_range_learning(self): + """ + Test that range learning requires `FakeQuantizer`s to be initialized correctly. + """ + config = FakeQuantizeConfig( + torch.int8, + "per_channel", + is_dynamic=False, + range_learning=True, + scale_precision=torch.float32, + zero_point_precision=torch.float32, + ) + fake_quantizer = FakeQuantizer(config) + example_inputs = (torch.randn(2, 3),) + + # Not initialized, should fail + self.assertFalse(fake_quantizer._initialized) + self.assertIsNone(fake_quantizer.scale) + self.assertIsNone(fake_quantizer.zero_point) + with self.assertRaisesRegex( + ValueError, + "Please call `torchao.quantization.qat.initialize_fake_quantizers` " + "before initializing the optimizer and beginning training.", + ): + fake_quantizer(*example_inputs) + + # Should pass after initializing + initialize_fake_quantizers(fake_quantizer, example_inputs) + self.assertTrue(fake_quantizer._initialized) + self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter) + self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter) + self.assertTrue(fake_quantizer.scale.requires_grad) + self.assertTrue(fake_quantizer.zero_point.requires_grad) + fake_quantizer(*example_inputs) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_range_learning(self): + """ + Test end-to-end QAT flow with range learning. + """ + config = FakeQuantizeConfig( + torch.int8, + "per_channel", + is_dynamic=False, + range_learning=True, + scale_precision=torch.float32, + zero_point_precision=torch.float32, + ) + m = M() + example_inputs = m.example_inputs() + quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config)) + + # Not initialized, should fail + for t in m._get_all_weight_qparams(): + self.assertIsNone(t) + with self.assertRaisesRegex( + ValueError, + "Please call `torchao.quantization.qat.initialize_fake_quantizers` " + "before initializing the optimizer and beginning training.", + ): + m(*example_inputs) + + # Should pass after initializing + # All scales and zero points should be in `m.parameters()` + initialize_fake_quantizers(m, example_inputs) + params = set(m.parameters()) + for t in m._get_all_weight_qparams(): + self.assertIsInstance(t, torch.nn.Parameter) + self.assertTrue(t.requires_grad) + self.assertTrue(t in params) + m(*example_inputs) + + # Simulate training + optimizer = torch.optim.SGD( + m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) + loss_fn = torch.nn.CrossEntropyLoss() + target = torch.randn(1, 512).float() + out = m(*example_inputs) + loss = loss_fn(out, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if __name__ == "__main__": unittest.main() diff --git a/third_party/cutlass b/third_party/cutlass index e94e888df3..afa1772203 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit e94e888df3551224738bfa505787b515eae8352f +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 5dc3d8e008..010ccfc8cc 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -4,6 +4,7 @@ FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, from_intx_quantization_aware_training, + initialize_fake_quantizers, intx_quantization_aware_training, ) from .embedding import ( @@ -17,11 +18,12 @@ __all__ = [ "ComposableQATQuantizer", "FakeQuantizeConfig", - "Int4WeightOnlyQATQuantizer", + "FromIntXQuantizationAwareTrainingConfig", "Int4WeightOnlyEmbeddingQATQuantizer", + "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", + "IntXQuantizationAwareTrainingConfig", + "initialize_fake_quantizers", "intx_quantization_aware_training", "from_intx_quantization_aware_training", - "FromIntXQuantizationAwareTrainingConfig", - "IntXQuantizationAwareTrainingConfig", ] diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index e025a43d94..8fba195363 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import torch @@ -51,7 +51,8 @@ class FakeQuantizeConfig: zero_point_precision: zero point dtype (default torch.int32) zero_point_domain: whether zero point is in integer (default) or float domain is_dynamic: whether to use dynamic (default) or static scale and zero points - range_learning: whether to learn scale and zero points during training (coming soon) + range_learning: whether to learn scale and zero points during training + (default false), not compatible with `is_dynamic`. kwargs (optional): group_size: size of each group in per group fake quantization, @@ -123,6 +124,10 @@ def __init__( "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) ) + # Dynamic is not compatible with range learning + if is_dynamic and range_learning: + raise ValueError("`is_dynamic` is not compatible with `range_learning`") + def _get_granularity( self, granularity: Union[Granularity, str, None], @@ -394,3 +399,23 @@ def convert( for quantizer in self.quantizers: model = quantizer.convert(model) return model + + +def initialize_fake_quantizers( + model: torch.nn.Module, + example_inputs: Tuple[Any, ...], +) -> None: + """ + Initialize the scales and zero points on all + :class:`~`torchao.quantization.qat.fake_quantizer.FakeQuantizer` + in the model based on the provided example inputs. + """ + # avoid circular dependencies + from torchao.quantization.qat.fake_quantizer import FakeQuantizer + + def _set_initialized(m: torch.nn.Module): + if isinstance(m, FakeQuantizer): + m._initialized = True + + model.apply(_set_initialized) + model(*example_inputs) diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index 2770956a2c..aec23712ed 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -92,6 +92,7 @@ def to_embedding(self) -> torch.nn.Embedding: self.scale_grad_by_freq, self.sparse, device=self.weight.device, + dtype=self.weight.dtype, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to @@ -116,6 +117,7 @@ def from_embedding( mod.sparse, weight_config=weight_config, device=mod.weight.device, + dtype=mod.weight.dtype, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index 0d2521cac0..90206b5d6e 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -31,6 +31,7 @@ from .utils import ( _fake_quantize_per_channel_group, _fake_quantize_per_token, + _Round, ) @@ -46,11 +47,12 @@ def __init__(self, config: FakeQuantizeConfig): self.scale: Optional[torch.Tensor] = None self.zero_point: Optional[torch.Tensor] = None - # TODO: support range learinng - if self.config.range_learning: - raise NotImplementedError("Range learning is not supported yet") + # For range learning only + # TODO: make this configurable? + self._scale_eps = 1e-9 + self._initialized = False - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply fake quantization to the tensor based on the bit-width, granularity, symmetry, and other properties specified in the config. @@ -58,6 +60,17 @@ def forward(self, x: torch.Tensor): if not self.enabled: return x + if ( + self.config.range_learning + and not self._initialized + and (self.scale is None or self.zero_point is None) + ): + raise ValueError( + "Scales and zero points must be initialized for range learning. " + "Please call `torchao.quantization.qat.initialize_fake_quantizers` " + "before initializing the optimizer and beginning training." + ) + if isinstance(self.config.granularity, PerToken): return self._per_token_forward(x) elif isinstance(self.config.granularity, (PerAxis, PerGroup)): @@ -65,13 +78,12 @@ def forward(self, x: torch.Tensor): else: raise ValueError("Unknown granularity '%s'" % self.config.granularity) - def _per_token_forward(self, x: torch.Tensor): + def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor: """ Perform per token fake quantization on the tensor. """ if self.config.is_symmetric: raise NotImplementedError("Symmetric per token is not supported yet") - qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] if self._should_compute_qparams(): self.scale, self.zero_point = choose_qparams_affine( @@ -85,9 +97,10 @@ def _per_token_forward(self, x: torch.Tensor): scale_dtype=self.config.scale_precision, zero_point_dtype=self.config.zero_point_precision, ) + self._maybe_update_qparams_for_range_learning() return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax) - def _per_channel_or_group_forward(self, x: torch.Tensor): + def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor: """ Perform per channel or per group fake quantization on the tensor. We express per channel using per group where the group size is the size @@ -129,6 +142,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor): eps=self.config.eps, ) self.zero_point = self.zero_point.to(zero_point_precision) + self._maybe_update_qparams_for_range_learning() qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] return _fake_quantize_per_channel_group( @@ -147,6 +161,26 @@ def _should_compute_qparams(self) -> bool: """ return self.config.is_dynamic or self.scale is None or self.zero_point is None + def _maybe_update_qparams_for_range_learning(self) -> None: + """ + If range learning is enabled, turn scales and zero points into trainable parameters. + This function is idempotent and should only be called once. + """ + if ( + not self.config.range_learning + or isinstance(self.scale, torch.nn.Parameter) + or isinstance(self.zero_point, torch.nn.Parameter) + ): + return + scale, zero_point = self.scale, self.zero_point + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype] + # Stabilize range learning + scale = torch.clamp(scale, min=self._scale_eps) + zero_point = _Round.apply(zero_point) + zero_point = torch.clamp(zero_point, qmin, qmax) + self.scale = torch.nn.Parameter(scale, requires_grad=True) + self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True) + def __repr__(self) -> str: """ Return a human readable representation of this `FakeQuantizer` with config details. diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index a912f04b83..7c32bc4b19 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -18,6 +18,7 @@ _replace_linear_int4, groupwise_affine_quantize_tensor, ) +from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_primitives import ( TorchAODType, ZeroPointDomain, @@ -83,12 +84,13 @@ def __init__( # initialize weight fake quantizer if weight_config is not None: - group_size = weight_config.group_size - if group_size is not None and in_features % group_size != 0: - raise ValueError( - "in_features (%s) %% group_size (%s) must be == 0" - % (in_features, group_size) - ) + if isinstance(weight_config.granularity, PerGroup): + group_size = weight_config.group_size + if group_size is not None and in_features % group_size != 0: + raise ValueError( + "in_features (%s) %% group_size (%s) must be == 0" + % (in_features, group_size) + ) self.weight_fake_quantizer = FakeQuantizer(weight_config) else: self.weight_fake_quantizer = None @@ -108,6 +110,7 @@ def to_linear(self) -> torch.nn.Linear: self.out_features, self.bias is not None, device=self.weight.device, + dtype=self.weight.dtype, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to @@ -131,6 +134,7 @@ def from_linear( activation_config=activation_config, weight_config=weight_config, device=mod.weight.device, + dtype=mod.weight.dtype, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index 12e9097ada..71e9a96ec5 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -91,6 +91,20 @@ def backward(ctx, gy): return (gy,) +class _Round(torch.autograd.Function): + """ + Implementation of generic round operation with backward STE. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor) -> torch.Tensor: + return torch.round(x) + + @staticmethod + def backward(ctx, gy: torch.Tensor) -> torch.Tensor: + return gy + + def _fake_quantize_per_channel_group( input: torch.Tensor, scales: torch.Tensor,