
Commit 2f0b504

[draft] Pass QAT learned qparams in convert
**Summary:** Draft prototype to pass scales and zero points learned during QAT range learning through to the PTQ base config at convert time.

**Test Plan:** TBD
Parent: 9e5059e

3 files changed (+27 −14)
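For context, a minimal sketch of the end-to-end flow this draft targets. The API names (`QATConfig`, `IntxFakeQuantizeConfig`, `quantize_`) follow the torchao QAT README; the specific dtype, group size, and the `initialize_fake_quantizers` call are illustrative assumptions, not taken from this commit:

```python
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
    IntxFakeQuantizeConfig,
    QATConfig,
    initialize_fake_quantizers,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Prepare: swap nn.Linear for FakeQuantizedLinear. With range_learning=True,
# the weight scales/zero points become trainable parameters instead of being
# recomputed from min/max statistics on every step.
weight_config = IntxFakeQuantizeConfig(
    torch.int4, group_size=32, is_dynamic=False, range_learning=True
)
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"))
initialize_fake_quantizers(model, (torch.randn(1, 64),))

# ... fine-tune; the optimizer updates the learned qparams ...

# Convert: with this draft, the learned scales/zero points are forwarded to
# the PTQ base config handler instead of being re-derived from the weights.
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(32)
)
quantize_(model, QATConfig(base_config, step="convert"))
```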

torchao/quantization/qat/api.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -227,15 +227,21 @@ def _qat_config_transform(
         assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
         assert config.activation_config is None, "unexpected `activation_config`"
         assert config.weight_config is None, "unexpected `weight_config`"
+        kwargs = {}
         if isinstance(module, FakeQuantizedLinear):
+            # Optionally pass custom scales and zero points to the base config handler
+            weight_config = module.weight_fake_quantizer.config
+            if isinstance(weight_config, IntxFakeQuantizeConfig) and weight_config.range_learning:
+                kwargs["custom_scale"] = weight_config.scale
+                kwargs["custom_zero_point"] = weight_config.zero_point
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
         else:
             # Unrelated module, ignore
             return module
         if base_config is not None:
-            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
+            return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config, **kwargs)
         else:
             return module
```
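Two details in this hunk are worth calling out. The learned qparams are harvested only when the weight fake quantizer was configured with `range_learning=True`; dynamic fake-quantize configs derive fresh scales from each batch, so they have no persistent values to carry into convert. Passing the values through `**kwargs` also keeps the base-config handlers backward compatible: for any module that does not hit the range-learning branch, `kwargs` stays empty and the handler call is unchanged.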

torchao/quantization/quant_api.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -809,7 +809,7 @@ def __post_init__(self):
         )


-def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
+def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config, **kwargs):
     weight_dtype = config.weight_dtype
     weight_granularity = config.weight_granularity
     weight_mapping_type = config.weight_mapping_type
@@ -853,6 +853,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
         weight_dtype,
         mapping_type=weight_mapping_type,
         activation_quantization="int8_asym_per_token",
+        **kwargs,
     )
     if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -939,10 +940,10 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):

 @register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
 def _int8_dynamic_activation_intx_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig
+    module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig, **kwargs
 ) -> torch.nn.Module:
     new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
-        module.weight, module.bias, config
+        module.weight, module.bias, config, **kwargs
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     if new_bias is None:
@@ -2177,7 +2178,7 @@ def __post_init__(self):
         )


-def _intx_weight_only_quantize_tensor(weight, config):
+def _intx_weight_only_quantize_tensor(weight, config, **kwargs):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
@@ -2207,6 +2208,7 @@ def _intx_weight_only_quantize_tensor(weight, config):
         block_size,
         weight_dtype,
         mapping_type=mapping_type,
+        **kwargs,
     )
     if scale_dtype is not None and scale_dtype != weight.dtype:
         _adjust_scale_dtype_in_intx_unpacked_tensor(
```
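Both transforms above are looked up through torchao's config-to-handler registry. A stripped-down sketch of that mechanism (simplified for illustration; the real registry and decorator live in torchao/quantization/quant_api.py):

```python
from typing import Callable, Dict, Type

# Maps a quantization config class to the function that transforms a module.
# Simplified stand-in for the registry in torchao/quantization/quant_api.py.
_QUANTIZE_CONFIG_HANDLER: Dict[Type, Callable] = {}


def register_quantize_module_handler(config_cls: Type) -> Callable:
    """Decorator that registers a handler as the transform for config_cls."""

    def decorator(handler: Callable) -> Callable:
        _QUANTIZE_CONFIG_HANDLER[config_cls] = handler
        return handler

    return decorator


# Because the handlers now accept **kwargs, a caller such as
# _qat_config_transform can look a handler up by config type and thread
# extra arguments through it:
#   handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
#   handler(module, base_config, custom_scale=..., custom_zero_point=...)
```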

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -177,20 +177,25 @@ def from_hp(
         activation_quantization: Optional[
             IntxUnpackedToInt8TensorActivationQuantization
         ] = None,
+        custom_scale: Optional[torch.Tensor] = None,
+        custom_zero_point: Optional[torch.Tensor] = None,
     ):
         """
         Create an IntxUnpackedToInt8Tensor from a high-precision tensor
         """
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
-        scale, zero_point = choose_qparams_affine(
-            hp_tensor,
-            mapping_type,
-            block_size,
-            target_dtype=torch.int8,
-            quant_min=qmin,
-            quant_max=qmax,
-            zero_point_dtype=torch.int8,
-        )
+        if custom_scale is None or custom_zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                hp_tensor,
+                mapping_type,
+                block_size,
+                target_dtype=torch.int8,
+                quant_min=qmin,
+                quant_max=qmax,
+                zero_point_dtype=torch.int8,
+            )
+        else:
+            scale, zero_point = custom_scale, custom_zero_point
         qdata = quantize_affine(
             hp_tensor,
             block_size,
```
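A hypothetical call into the new escape hatch. The positional order `(hp_tensor, block_size, target_dtype)` is inferred from the call sites in quant_api.py above, and the toy tensor values and qparam shapes are assumptions for illustration:

```python
import torch
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
    IntxUnpackedToInt8Tensor,
)

w = torch.randn(8, 32)
block_size = (1, 32)  # one quantization group per output row (assumed)

# Default path: choose_qparams_affine derives scale/zero_point from the weight.
q_default = IntxUnpackedToInt8Tensor.from_hp(w, block_size, torch.int4)

# New path: reuse qparams learned during QAT range learning. Both tensors must
# be provided; otherwise from_hp falls back to choose_qparams_affine.
scale = torch.full((8, 1), 0.05)
zero_point = torch.zeros(8, 1, dtype=torch.int8)
q_learned = IntxUnpackedToInt8Tensor.from_hp(
    w,
    block_size,
    torch.int4,
    custom_scale=scale,
    custom_zero_point=zero_point,
)
```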
