From 75c6c8da269d328a32ad869998db15bb79606628 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Wed, 17 Sep 2025 10:23:17 -0700
Subject: [PATCH] Pass QAT learned qparams in convert

**Summary:** Add support for passing the scales and zero points learned
during QAT range learning to the PTQ base config during convert.
Currently only the following configs support this feature:

```
IntxWeightOnlyConfig
Int8DynamicActivationInt4WeightConfig
Int8DynamicActivationIntxWeightConfig
```

During the convert phase, QAT detects whether range learning was used
during training and, if so, passes the learned scales and zero points
as custom qparams to the quantized tensor subclass, so the converted
(PTQ) model produces numerics consistent with the QAT-prepared model.

Fixes part of https://github.com/pytorch/ao/issues/2271.
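Example flow (a minimal sketch adapted from the new test below; the toy
model, example inputs, and group size are illustrative placeholders, and
the import paths are assumed from the current torchao layout):

```
import torch
from torchao.quantization import (
    Int8DynamicActivationIntxWeightConfig,
    PerGroup,
    quantize_,
)
from torchao.quantization.qat import (
    IntxFakeQuantizeConfig,
    QATConfig,
    initialize_fake_quantizers,
)

# Toy model and example inputs, stand-ins for a real model
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
example_inputs = (torch.randn(1, 64),)

# Prepare: insert weight fake quantizers with range learning enabled
fq_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    is_dynamic=False,
    range_learning=True,
)
quantize_(model, QATConfig(weight_config=fq_config, step="prepare"))
initialize_fake_quantizers(model, example_inputs)

# ... training loop updates the learned scales and zero points ...

# Convert: the learned qparams are forwarded to the base config's handler
base_config = Int8DynamicActivationIntxWeightConfig(torch.int4, PerGroup(32))
quantize_(model, QATConfig(base_config, step="convert"))
```

If the weight fake quantizer was not configured with
`range_learning=True`, convert behaves as before and the base config
computes its own qparams.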
**Test Plan:**
```
python test/quantization/test_qat.py -k test_range_learning_convert_pass_qparams
```

---
 test/quantization/test_qat.py                      | 53 +++++++++++++++-
 torchao/dtypes/affine_quantized_tensor.py          | 11 +++-
 torchao/quantization/qat/api.py                    | 32 ++++++++--
 torchao/quantization/quant_api.py                  | 60 ++++++++++++++++---
 .../intx/intx_unpacked_to_int8_tensor.py           | 27 ++++++---
 5 files changed, 159 insertions(+), 24 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 64d2f02b2d..a6ef09e6e8 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -10,7 +10,7 @@
 import copy
 import unittest
 import warnings
-from typing import List
+from typing import List, Type
 
 import torch
 import torch.nn.functional as F
@@ -2304,6 +2304,57 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             rtol=0,
         )
 
+    @parametrize(
+        "base_config_cls",
+        [
+            IntxWeightOnlyConfig,
+            Int8DynamicActivationInt4WeightConfig,
+            Int8DynamicActivationIntxWeightConfig,
+        ],
+    )
+    def test_range_learning_convert_pass_qparams(
+        self, base_config_cls: Type[AOBaseConfig]
+    ):
+        """
+        Verify that range learning QAT can pass qparams from the prepared
+        model to the converted model.
+        """
+        group_size = 32
+        config = IntxFakeQuantizeConfig(
+            torch.int4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=False,
+            range_learning=True,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, QATConfig(weight_config=config, step="prepare"))
+        initialize_fake_quantizers(m, example_inputs)
+
+        # convert and verify scales are what we expect
+        scale1 = m.linear1.weight_fake_quantizer.scale
+        scale2 = m.linear2.weight_fake_quantizer.scale
+        sub_scale = m.sub.linear.weight_fake_quantizer.scale
+        if base_config_cls == Int8DynamicActivationInt4WeightConfig:
+            base_config = base_config_cls()
+            quantize_(m, QATConfig(base_config, step="convert"))
+            torch.testing.assert_close(
+                m.linear1.weight.original_weight_tensor.tensor_impl.scale, scale1
+            )
+            torch.testing.assert_close(
+                m.linear2.weight.original_weight_tensor.tensor_impl.scale, scale2
+            )
+            torch.testing.assert_close(
+                m.sub.linear.weight.original_weight_tensor.tensor_impl.scale, sub_scale
+            )
+        else:
+            base_config = base_config_cls(torch.int4, PerGroup(group_size))
+            quantize_(m, QATConfig(base_config, step="convert"))
+            torch.testing.assert_close(m.linear1.weight.scale, scale1)
+            torch.testing.assert_close(m.linear2.weight.scale, scale2)
+            torch.testing.assert_close(m.sub.linear.weight.scale, sub_scale)
+
 
 instantiate_parametrized_tests(TestQAT)
 
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 92a2de316a..0d7ed8d9e2 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -245,6 +245,9 @@ def from_hp_to_intx(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
+        *,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """Convert a high precision tensor to an integer affine quantized tensor."""
         original_shape = input_float.shape
@@ -288,7 +291,13 @@ def from_hp_to_intx(
             )
             data = data.to(target_dtype)
         else:
-            if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+            if (custom_scale is None) != (custom_zero_point is None):
+                raise ValueError(
+                    "`custom_scale` and `custom_zero_point` must be both defined or both None"
+                )
+            if custom_scale is not None and custom_zero_point is not None:
+                scale, zero_point = custom_scale, custom_zero_point
+            elif zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
                 scale, zero_point = _choose_qparams_affine_tinygemm(
                     input_float,
                     mapping_type,
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 5bf1729f69..1287126bac 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -21,6 +21,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
+    IntxFakeQuantizeConfig,
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
@@ -220,22 +221,41 @@ def _qat_config_transform(
         )
     else:
         # Convert step
+        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
+        assert config.activation_config is None, "unexpected `activation_config`"
+        assert config.weight_config is None, "unexpected `weight_config`"
+
+        # Ignore unrelated modules
+        if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
+            return module
+
+        # Optionally pass custom scales and zero points to base config handler
+        # This is only for range learning and only applies to weights
+        kwargs = {}
+        weight_config = module.weight_fake_quantizer.config
+        if (
+            isinstance(weight_config, IntxFakeQuantizeConfig)
+            and weight_config.range_learning
+        ):
+            kwargs["custom_scale"] = module.weight_fake_quantizer.scale
+            kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
+
         # Swap FakeQuantizedLinear -> nn.Linear
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
         # If there is no base config, then simply perform the module swap
-        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
-        assert config.activation_config is None, "unexpected `activation_config`"
-        assert config.weight_config is None, "unexpected `weight_config`"
         if isinstance(module, FakeQuantizedLinear):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
         else:
-            # Unrelated module, ignore
-            return module
+            raise ValueError(
+                f"Encountered unexpected module {module}, should never happen"
+            )
         if base_config is not None:
-            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](
+                module, base_config, **kwargs
+            )
         else:
             return module
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ef4b247819..68c0b42e76 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -647,7 +647,11 @@ def __post_init__(self):
 
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationInt4WeightConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ):
     group_size = config.group_size
     layout = config.layout
@@ -700,6 +704,8 @@ def _int8_dynamic_activation_int4_weight_transform(
             quant_min=0,
             quant_max=15,
             _layout=layout,
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
         )
     else:
         weight = to_affine_quantized_intx(
@@ -710,6 +716,8 @@ def _int8_dynamic_activation_int4_weight_transform(
             quant_min,
             quant_max,
             _layout=layout,
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
         )
     weight = to_linear_activation_quantized(weight, input_quant_func)
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -809,7 +817,14 @@ def __post_init__(self):
         )
 
 
-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(
+    weight,
+    bias,
+    config,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
+):
     weight_dtype = config.weight_dtype
     weight_granularity = config.weight_granularity
     weight_mapping_type = config.weight_mapping_type
@@ -847,12 +862,16 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
             intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8
             or intx_packing_format in opaque_formats
         ), f"Unsupported packing format: {intx_packing_format}"
+        if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
+            custom_zero_point = custom_zero_point.to(torch.int8)
         new_weight = IntxUnpackedToInt8Tensor.from_hp(
             weight,
             block_size,
             weight_dtype,
             mapping_type=weight_mapping_type,
             activation_quantization="int8_asym_per_token",
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
         )
         if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
             _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -939,10 +958,18 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(
 
 @register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
 def _int8_dynamic_activation_intx_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationIntxWeightConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
     new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
-        module.weight, module.bias, config
+        module.weight,
+        module.bias,
+        config,
+        custom_scale=custom_scale,
+        custom_zero_point=custom_zero_point,
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     if new_bias is None:
@@ -2177,7 +2204,13 @@ def __post_init__(self):
         )
 
 
-def _intx_weight_only_quantize_tensor(weight, config):
+def _intx_weight_only_quantize_tensor(
+    weight,
+    config,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
+):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
@@ -2202,11 +2235,15 @@ def _intx_weight_only_quantize_tensor(weight, config):
 
     if config.version == 2:
         if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
+            if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
+                custom_zero_point = custom_zero_point.to(torch.int8)
             new_weight = IntxUnpackedToInt8Tensor.from_hp(
                 weight,
                 block_size,
                 weight_dtype,
                 mapping_type=mapping_type,
+                custom_scale=custom_scale,
+                custom_zero_point=custom_zero_point,
             )
             if scale_dtype is not None and scale_dtype != weight.dtype:
                 _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -2241,13 +2278,22 @@ def _intx_weight_only_quantize_tensor(weight, config):
 
 @register_quantize_module_handler(IntxWeightOnlyConfig)
 def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
+    module: torch.nn.Module,
+    config: IntxWeightOnlyConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
     assert hasattr(module, "weight"), (
         "applying intx weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    new_weight = _intx_weight_only_quantize_tensor(
+        module.weight,
+        config,
+        custom_scale=custom_scale,
+        custom_zero_point=custom_zero_point,
+    )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 
     if isinstance(module, nn.Linear):
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index e9d79fc670..87402241dd 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -177,20 +177,29 @@ def from_hp(
         activation_quantization: Optional[
             IntxUnpackedToInt8TensorActivationQuantization
         ] = None,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
-        scale, zero_point = choose_qparams_affine(
-            hp_tensor,
-            mapping_type,
-            block_size,
-            target_dtype=torch.int8,
-            quant_min=qmin,
-            quant_max=qmax,
-            zero_point_dtype=torch.int8,
-        )
+        if custom_scale is not None and custom_zero_point is not None:
+            scale, zero_point = custom_scale, custom_zero_point
+        elif custom_scale is None and custom_zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                hp_tensor,
+                mapping_type,
+                block_size,
+                target_dtype=torch.int8,
+                quant_min=qmin,
+                quant_max=qmax,
+                zero_point_dtype=torch.int8,
+            )
+        else:
+            raise ValueError(
+                "`custom_scale` and `custom_zero_point` must be both defined or both None"
+            )
         qdata = quantize_affine(
             hp_tensor,
             block_size,