Commit ae12e42
Pass QAT learned qparams in convert (#3022)
**Summary:** Add support to pass scales and zero points learned during QAT range learning to the PTQ base config. Currently only the following configs support this feature:

```
IntxWeightOnlyConfig
Int8DynamicActivationInt4WeightConfig
Int8DynamicActivationIntxWeightConfig
```

During the convert phase, QAT will detect if range learning was used during training, and pass the learned scales and zero points as custom qparams to the quantized tensor subclass, so PTQ will produce more consistent numerics. Fixes part of #2271.

**Test Plan:**

```
python test/quantization/test_qat.py -k test_range_learning_convert_pass_qparams
```
1 parent f35dcd7 commit ae12e42

5 files changed: +159 additions, −24 deletions
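For context, the end-to-end flow this commit enables looks roughly like the sketch below, adapted from the new test added in `test/quantization/test_qat.py`. The tiny model and example inputs are placeholders, and the import paths reflect torchao's current layout as best I can tell, so they may need adjusting for your version:

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
    IntxFakeQuantizeConfig,
    QATConfig,
    initialize_fake_quantizers,
)

# Placeholder model and calibration inputs
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
example_inputs = (torch.randn(1, 64),)

# Prepare: insert weight fake quantizers with range learning enabled
fq_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    is_dynamic=False,
    range_learning=True,
)
quantize_(model, QATConfig(weight_config=fq_config, step="prepare"))
initialize_fake_quantizers(model, example_inputs)

# ... fine-tune here; the scales and zero points are learned parameters ...

# Convert: because range learning was used, the learned scales/zero points are
# passed to the base config's handler instead of being recomputed, so the
# converted model's qparams match the prepared model's.
base_config = IntxWeightOnlyConfig(torch.int4, PerGroup(32))
quantize_(model, QATConfig(base_config, step="convert"))
```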

test/quantization/test_qat.py

Lines changed: 52 additions & 1 deletion
@@ -10,7 +10,7 @@
 import copy
 import unittest
 import warnings
-from typing import List
+from typing import List, Type
 
 import torch
 import torch.nn.functional as F
@@ -2304,6 +2304,57 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
             rtol=0,
         )
 
+    @parametrize(
+        "base_config_cls",
+        [
+            IntxWeightOnlyConfig,
+            Int8DynamicActivationInt4WeightConfig,
+            Int8DynamicActivationIntxWeightConfig,
+        ],
+    )
+    def test_range_learning_convert_pass_qparams(
+        self, base_config_cls: Type[AOBaseConfig]
+    ):
+        """
+        Verify that range learning QAT can pass qparams from the prepared
+        model to the convert model.
+        """
+        group_size = 32
+        config = IntxFakeQuantizeConfig(
+            torch.int4,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=False,
+            range_learning=True,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, QATConfig(weight_config=config, step="prepare"))
+        initialize_fake_quantizers(m, example_inputs)
+
+        # convert and verify scales are what we expect
+        scale1 = m.linear1.weight_fake_quantizer.scale
+        scale2 = m.linear2.weight_fake_quantizer.scale
+        sub_scale = m.sub.linear.weight_fake_quantizer.scale
+        if base_config_cls == Int8DynamicActivationInt4WeightConfig:
+            base_config = base_config_cls()
+            quantize_(m, QATConfig(base_config, step="convert"))
+            torch.testing.assert_close(
+                m.linear1.weight.original_weight_tensor.tensor_impl.scale, scale1
+            )
+            torch.testing.assert_close(
+                m.linear2.weight.original_weight_tensor.tensor_impl.scale, scale2
+            )
+            torch.testing.assert_close(
+                m.sub.linear.weight.original_weight_tensor.tensor_impl.scale, sub_scale
+            )
+        else:
+            base_config = base_config_cls(torch.int4, PerGroup(group_size))
+            quantize_(m, QATConfig(base_config, step="convert"))
+            torch.testing.assert_close(m.linear1.weight.scale, scale1)
+            torch.testing.assert_close(m.linear2.weight.scale, scale2)
+            torch.testing.assert_close(m.sub.linear.weight.scale, sub_scale)
+
 
 instantiate_parametrized_tests(TestQAT)

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 10 additions & 1 deletion
@@ -245,6 +245,9 @@ def from_hp_to_intx(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
+        *,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """Convert a high precision tensor to an integer affine quantized tensor."""
         original_shape = input_float.shape
@@ -288,7 +291,13 @@ def from_hp_to_intx(
             )
             data = data.to(target_dtype)
         else:
-            if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
+            if custom_scale is None != custom_zero_point is None:
+                raise ValueError(
+                    "`custom_scale` and `custom_zero_point` must be both defined or both None"
+                )
+            if custom_scale is not None and custom_zero_point is not None:
+                scale, zero_point = custom_scale, custom_zero_point
+            elif zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
                 scale, zero_point = _choose_qparams_affine_tinygemm(
                     input_float,
                     mapping_type,

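The guard added above enforces a both-or-neither contract for the new keyword-only arguments. As a standalone sketch, the same check (with a hypothetical helper name) reads:

```python
from typing import Optional

import torch


def _check_custom_qparams(
    custom_scale: Optional[torch.Tensor],
    custom_zero_point: Optional[torch.Tensor],
) -> bool:
    """Return True if caller-provided qparams should be used, else False."""
    if (custom_scale is None) != (custom_zero_point is None):
        raise ValueError(
            "`custom_scale` and `custom_zero_point` must be both defined or both None"
        )
    return custom_scale is not None
```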

torchao/quantization/qat/api.py

Lines changed: 26 additions & 6 deletions
@@ -21,6 +21,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
+    IntxFakeQuantizeConfig,
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
@@ -220,22 +221,41 @@ def _qat_config_transform(
         )
     else:
         # Convert step
+        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
+        assert config.activation_config is None, "unexpected `activation_config`"
+        assert config.weight_config is None, "unexpected `weight_config`"
+
+        # Ignore unrelated modules
+        if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
+            return module
+
+        # Optionally pass custom scales and zero points to base config handler
+        # This is only for range learning and only applies to weights
+        kwargs = {}
+        weight_config = module.weight_fake_quantizer.config
+        if (
+            isinstance(weight_config, IntxFakeQuantizeConfig)
+            and weight_config.range_learning
+        ):
+            kwargs["custom_scale"] = module.weight_fake_quantizer.scale
+            kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
+
         # Swap FakeQuantizedLinear -> nn.Linear
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
         # If there is no base config, then simply perform the module swap
-        assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
-        assert config.activation_config is None, "unexpected `activation_config`"
-        assert config.weight_config is None, "unexpected `weight_config`"
         if isinstance(module, FakeQuantizedLinear):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
         else:
-            # Unrelated module, ignore
-            return module
+            raise ValueError(
+                f"Encountered unexpected module {module}, should never happen"
+            )
         if base_config is not None:
-            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](
+                module, base_config, **kwargs
+            )
         else:
             return module

torchao/quantization/quant_api.py

Lines changed: 53 additions & 7 deletions
@@ -644,7 +644,11 @@ def __post_init__(self):
 
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationInt4WeightConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ):
     group_size = config.group_size
     layout = config.layout
@@ -697,6 +701,8 @@ def _int8_dynamic_activation_int4_weight_transform(
             quant_min=0,
             quant_max=15,
             _layout=layout,
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
         )
     else:
         weight = to_affine_quantized_intx(
@@ -707,6 +713,8 @@ def _int8_dynamic_activation_int4_weight_transform(
             quant_min,
             quant_max,
             _layout=layout,
+            custom_scale=custom_scale,
+            custom_zero_point=custom_zero_point,
         )
     weight = to_linear_activation_quantized(weight, input_quant_func)
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -806,7 +814,14 @@ def __post_init__(self):
         )
 
 
-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(
+    weight,
+    bias,
+    config,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
+):
     weight_dtype = config.weight_dtype
     weight_granularity = config.weight_granularity
     weight_mapping_type = config.weight_mapping_type
@@ -844,12 +859,16 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
         intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8
         or intx_packing_format in opaque_formats
     ), f"Unsupported packing format: {intx_packing_format}"
+    if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
+        custom_zero_point = custom_zero_point.to(torch.int8)
     new_weight = IntxUnpackedToInt8Tensor.from_hp(
         weight,
         block_size,
         weight_dtype,
         mapping_type=weight_mapping_type,
         activation_quantization="int8_asym_per_token",
+        custom_scale=custom_scale,
+        custom_zero_point=custom_zero_point,
     )
     if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -936,10 +955,18 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
 
 @register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
 def _int8_dynamic_activation_intx_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationIntxWeightConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
     new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
-        module.weight, module.bias, config
+        module.weight,
+        module.bias,
+        config,
+        custom_scale=custom_scale,
+        custom_zero_point=custom_zero_point,
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     if new_bias is None:
@@ -2179,7 +2206,13 @@ def __post_init__(self):
         )
 
 
-def _intx_weight_only_quantize_tensor(weight, config):
+def _intx_weight_only_quantize_tensor(
+    weight,
+    config,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
+):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
@@ -2204,11 +2237,15 @@ def _intx_weight_only_quantize_tensor(weight, config):
 
     if config.version == 2:
         if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
+            if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
+                custom_zero_point = custom_zero_point.to(torch.int8)
             new_weight = IntxUnpackedToInt8Tensor.from_hp(
                 weight,
                 block_size,
                 weight_dtype,
                 mapping_type=mapping_type,
+                custom_scale=custom_scale,
+                custom_zero_point=custom_zero_point,
             )
             if scale_dtype is not None and scale_dtype != weight.dtype:
                 _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -2243,13 +2280,22 @@ def _intx_weight_only_quantize_tensor(weight, config):
 
 @register_quantize_module_handler(IntxWeightOnlyConfig)
 def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
+    module: torch.nn.Module,
+    config: IntxWeightOnlyConfig,
+    *,
+    custom_scale: Optional[torch.Tensor] = None,
+    custom_zero_point: Optional[torch.Tensor] = None,
 ) -> torch.nn.Module:
     assert hasattr(module, "weight"), (
         "applying intx weight only quant requires module to have weight attribute"
        + " but {module} does not have one"
     )
-    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    new_weight = _intx_weight_only_quantize_tensor(
+        module.weight,
+        config,
+        custom_scale=custom_scale,
+        custom_zero_point=custom_zero_point,
+    )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
 
     if isinstance(module, nn.Linear):

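For the two Intx configs touched above, the practical effect is easiest to see on the quantized weight subclass: under plain PTQ the scale and zero point are recomputed from the weights, while the QAT convert path now hands over the learned values (the new test asserts they match exactly). A hedged sketch of applying the same base config the test uses to a bare linear layer, assuming the import locations below:

```python
import torch
from torch import nn
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerGroup

# Same base config the new test constructs for the weight-only case
linear = nn.Linear(256, 128)
quantize_(linear, IntxWeightOnlyConfig(torch.int4, PerGroup(32)))
print(type(linear.weight))

# Reached directly like this, the qparams are computed from linear.weight; when
# the same config is reached via QATConfig(step="convert") after range learning,
# the handler receives custom_scale/custom_zero_point and the resulting weight
# subclass carries the learned values instead.
```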

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 18 additions & 9 deletions
@@ -177,20 +177,29 @@ def from_hp(
         activation_quantization: Optional[
             IntxUnpackedToInt8TensorActivationQuantization
         ] = None,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
-        scale, zero_point = choose_qparams_affine(
-            hp_tensor,
-            mapping_type,
-            block_size,
-            target_dtype=torch.int8,
-            quant_min=qmin,
-            quant_max=qmax,
-            zero_point_dtype=torch.int8,
-        )
+        if custom_scale is not None and custom_zero_point is not None:
+            scale, zero_point = custom_scale, custom_zero_point
+        elif custom_scale is None and custom_zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                hp_tensor,
+                mapping_type,
+                block_size,
+                target_dtype=torch.int8,
+                quant_min=qmin,
+                quant_max=qmax,
+                zero_point_dtype=torch.int8,
+            )
+        else:
+            raise ValueError(
+                "`custom_scale` and `custom_zero_point` must be both defined or both None"
+            )
         qdata = quantize_affine(
             hp_tensor,
             block_size,

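Conceptually, when `custom_scale` and `custom_zero_point` are provided, `from_hp` skips the qparam search and quantizes directly with the caller's values. A rough standalone illustration of that affine quantization step using plain torch ops (not the torchao kernels), for a per-group symmetric int4 layout like the one in the test:

```python
import torch


def quantize_with_custom_qparams(
    hp: torch.Tensor,           # high-precision weight, shape [out, in]
    scale: torch.Tensor,        # per-group scale, shape [out, in // group_size]
    zero_point: torch.Tensor,   # per-group zero point, same shape as scale
    group_size: int,
    qmin: int = -8,
    qmax: int = 7,              # int4 bounds, matching the test's torch.int4 weights
) -> torch.Tensor:
    # q = clamp(round(x / scale) + zero_point, qmin, qmax), applied group-wise
    out_features, in_features = hp.shape
    groups = hp.reshape(out_features, in_features // group_size, group_size)
    q = torch.round(groups / scale.unsqueeze(-1)) + zero_point.unsqueeze(-1)
    return q.clamp(qmin, qmax).reshape(hp.shape).to(torch.int8)


# Example: stand-in "learned" per-group qparams for a [64, 128] weight, group_size=32
w = torch.randn(64, 128)
learned_scale = w.reshape(64, 4, 32).abs().amax(dim=-1) / 7
learned_zp = torch.zeros(64, 4)  # symmetric quantization -> zero point of 0
qw = quantize_with_custom_qparams(w, learned_scale, learned_zp, group_size=32)
```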
