
Commit adb966c

[draft] Pass QAT learned qparams in convert
**Summary:** Draft prototype to pass scales and zero points learned during QAT range learning to the PTQ base config.

**Test Plan:** TBD
1 parent 9e5059e commit adb966c
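
For context, a rough end-to-end sketch of the flow this prototype targets: prepare a linear layer for QAT with a range-learning Intx weight fake quantizer, train, then convert against a PTQ base config so the learned scale/zero_point are reused instead of being re-derived from the weights. Constructor arguments and import paths below are assumptions for illustration; only QATConfig, IntxFakeQuantizeConfig, Int8DynamicActivationIntxWeightConfig, and the custom_scale/custom_zero_point plumbing come from this commit.

import torch

# Import paths and constructor arguments are assumptions for illustration.
from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256))

# Weight fake quantizer with range learning enabled (argument names assumed).
weight_fq = IntxFakeQuantizeConfig(torch.int4, group_size=32, range_learning=True)

# Prepare: Linear -> FakeQuantizedLinear; training updates the learned scale/zero_point.
quantize_(model, QATConfig(weight_config=weight_fq, step="prepare"))
# ... QAT training loop (range learning may also require initializing the
# fake quantizers before training; omitted here) ...

# Convert: with this change, the learned scale/zero_point on the weight fake
# quantizer config are forwarded to the base config handler as
# custom_scale / custom_zero_point instead of being recomputed.
base_config = Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4)
quantize_(model, QATConfig(base_config, step="convert"))

The diffs below implement the convert-side plumbing for this flow.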

File tree: 3 files changed, +44 −14 lines

torchao/quantization/qat/api.py

Lines changed: 13 additions & 1 deletion
@@ -21,6 +21,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
+    IntxFakeQuantizeConfig,
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
@@ -227,15 +228,26 @@ def _qat_config_transform(
         assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
         assert config.activation_config is None, "unexpected `activation_config`"
         assert config.weight_config is None, "unexpected `weight_config`"
+        kwargs = {}
         if isinstance(module, FakeQuantizedLinear):
+            # Optionally pass custom scales and zero points to base config handler
+            weight_config = module.weight_fake_quantizer.config
+            if (
+                isinstance(weight_config, IntxFakeQuantizeConfig)
+                and weight_config.range_learning
+            ):
+                kwargs["custom_scale"] = weight_config.scale
+                kwargs["custom_zero_point"] = weight_config.zero_point
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
         else:
             # Unrelated module, ignore
             return module
         if base_config is not None:
-            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](
+                module, base_config, **kwargs
+            )
         else:
             return module
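
In effect, for a FakeQuantizedLinear whose weight fake quantizer is an IntxFakeQuantizeConfig with range_learning enabled, the convert step now calls the registered base-config handler with the learned qparams attached. A simplified sketch of the resulting call, where learned_scale and learned_zero_point stand in for weight_config.scale and weight_config.zero_point:

# Simplified view of the new dispatch above; not a literal excerpt.
handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
quantized = handler(
    module,                                # plain nn.Linear from FakeQuantizedLinear.to_linear()
    base_config,
    custom_scale=learned_scale,            # weight_config.scale learned during QAT
    custom_zero_point=learned_zero_point,  # weight_config.zero_point learned during QAT
)
# When range learning is off (or the weight config is not Intx), kwargs stays
# empty and this reduces to handler(module, base_config) as before.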

torchao/quantization/quant_api.py

Lines changed: 13 additions & 4 deletions
@@ -809,7 +809,9 @@ def __post_init__(self):
         )


-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(
+    weight, bias, config, **kwargs
+):
     weight_dtype = config.weight_dtype
     weight_granularity = config.weight_granularity
     weight_mapping_type = config.weight_mapping_type
@@ -853,6 +855,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
         weight_dtype,
         mapping_type=weight_mapping_type,
         activation_quantization="int8_asym_per_token",
+        **kwargs,
     )
     if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -939,10 +942,15 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):

 @register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
 def _int8_dynamic_activation_intx_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationIntxWeightConfig,
+    **kwargs,
 ) -> torch.nn.Module:
     new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
-        module.weight, module.bias, config
+        module.weight,
+        module.bias,
+        config,
+        **kwargs,
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     if new_bias is None:
@@ -2177,7 +2185,7 @@ def __post_init__(self):
         )


-def _intx_weight_only_quantize_tensor(weight, config):
+def _intx_weight_only_quantize_tensor(weight, config, **kwargs):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
@@ -2207,6 +2215,7 @@ def _intx_weight_only_quantize_tensor(weight, config):
         block_size,
         weight_dtype,
         mapping_type=mapping_type,
+        **kwargs,
     )
     if scale_dtype is not None and scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
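
Because the QAT convert path now forwards **kwargs unconditionally, any base-config handler reachable from that path needs to accept (and ideally forward) the extra keyword arguments; a handler with a strict (module, config) signature would raise a TypeError when custom qparams are passed. A minimal sketch of that handler shape, using hypothetical names:

# Hypothetical config/handler names for illustration only; the pattern mirrors
# the Int8DynamicActivationIntxWeightConfig changes above.
@register_quantize_module_handler(SomeOtherWeightConfig)
def _some_other_weight_transform(
    module: torch.nn.Module,
    config: SomeOtherWeightConfig,
    **kwargs,  # e.g. custom_scale / custom_zero_point from the QAT convert step
) -> torch.nn.Module:
    new_weight = _some_other_weight_quantize_tensor(module.weight, config, **kwargs)
    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    return module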

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 18 additions & 9 deletions
@@ -177,20 +177,29 @@ def from_hp(
         activation_quantization: Optional[
             IntxUnpackedToInt8TensorActivationQuantization
         ] = None,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
-        scale, zero_point = choose_qparams_affine(
-            hp_tensor,
-            mapping_type,
-            block_size,
-            target_dtype=torch.int8,
-            quant_min=qmin,
-            quant_max=qmax,
-            zero_point_dtype=torch.int8,
-        )
+        if custom_scale is not None and custom_zero_point is not None:
+            scale, zero_point = custom_scale, custom_zero_point
+        elif custom_scale is None and custom_zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                hp_tensor,
+                mapping_type,
+                block_size,
+                target_dtype=torch.int8,
+                quant_min=qmin,
+                quant_max=qmax,
+                zero_point_dtype=torch.int8,
+            )
+        else:
+            raise ValueError(
+                "`custom_scale` and `custom_zero_point` must be both defined or both None"
+            )
         qdata = quantize_affine(
             hp_tensor,
             block_size,
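
The new arguments are all-or-nothing: pass both custom_scale and custom_zero_point to skip choose_qparams_affine, pass neither to keep the existing behavior, and any other combination raises ValueError. A rough sketch of calling from_hp directly with learned qparams; the positional argument order, shapes, and import path are assumptions inferred from the call sites above:

import torch

# Import path inferred from the file location in this diff (assumption).
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
    IntxUnpackedToInt8Tensor,
)
from torchao.quantization.quant_primitives import MappingType

hp_weight = torch.randn(256, 128)
block_size = (1, 32)  # per-group along the last dim (illustrative)

# Learned qparams, e.g. taken from an IntxFakeQuantizeConfig after range
# learning; one scale/zero_point per (row, group) block for this block_size.
custom_scale = torch.full((256, 4), 0.02)
custom_zero_point = torch.zeros(256, 4, dtype=torch.int8)

t = IntxUnpackedToInt8Tensor.from_hp(
    hp_weight,
    block_size,
    torch.int4,
    mapping_type=MappingType.SYMMETRIC,
    custom_scale=custom_scale,
    custom_zero_point=custom_zero_point,
)

# Passing only one of custom_scale / custom_zero_point raises ValueError.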
