Merged
53 changes: 52 additions & 1 deletion test/quantization/test_qat.py
@@ -10,7 +10,7 @@
import copy
import unittest
import warnings
from typing import List
from typing import List, Type

import torch
import torch.nn.functional as F
@@ -2304,6 +2304,57 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
rtol=0,
)

@parametrize(
"base_config_cls",
[
IntxWeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
],
)
def test_range_learning_convert_pass_qparams(
self, base_config_cls: Type[AOBaseConfig]
):
"""
Verify that range learning QAT can pass qparams from the prepared
model to the convert model.
"""
group_size = 32
config = IntxFakeQuantizeConfig(
torch.int4,
group_size=group_size,
is_symmetric=True,
is_dynamic=False,
range_learning=True,
)
m = M()
example_inputs = m.example_inputs()
quantize_(m, QATConfig(weight_config=config, step="prepare"))
initialize_fake_quantizers(m, example_inputs)

# convert and verify scales are what we expect
scale1 = m.linear1.weight_fake_quantizer.scale
scale2 = m.linear2.weight_fake_quantizer.scale
sub_scale = m.sub.linear.weight_fake_quantizer.scale
if base_config_cls == Int8DynamicActivationInt4WeightConfig:
base_config = base_config_cls()
quantize_(m, QATConfig(base_config, step="convert"))
torch.testing.assert_close(
m.linear1.weight.original_weight_tensor.tensor_impl.scale, scale1
)
torch.testing.assert_close(
m.linear2.weight.original_weight_tensor.tensor_impl.scale, scale2
)
torch.testing.assert_close(
m.sub.linear.weight.original_weight_tensor.tensor_impl.scale, sub_scale
)
else:
base_config = base_config_cls(torch.int4, PerGroup(group_size))
quantize_(m, QATConfig(base_config, step="convert"))
torch.testing.assert_close(m.linear1.weight.scale, scale1)
torch.testing.assert_close(m.linear2.weight.scale, scale2)
torch.testing.assert_close(m.sub.linear.weight.scale, sub_scale)


instantiate_parametrized_tests(TestQAT)

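For orientation, here is a minimal sketch of the full range-learning flow this test exercises, with the training loop elided. The toy model, example input, and import paths are assumptions, not part of the PR; the API calls mirror the test above.

```python
import torch
from torchao.quantization import IntxWeightOnlyConfig, PerGroup, quantize_  # assumed import paths
from torchao.quantization.qat import (  # assumed import paths
    IntxFakeQuantizeConfig,
    QATConfig,
    initialize_fake_quantizers,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))
example_inputs = (torch.randn(1, 64),)

# Prepare: insert weight fake quantizers whose scales/zero points are learnable
weight_config = IntxFakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    is_dynamic=False,
    range_learning=True,
)
quantize_(model, QATConfig(weight_config=weight_config, step="prepare"))
initialize_fake_quantizers(model, example_inputs)

# ... fine-tune here; each weight_fake_quantizer.scale is a learned parameter ...

# Convert: the learned qparams are carried over into the quantized weights
base_config = IntxWeightOnlyConfig(torch.int4, PerGroup(32))
quantize_(model, QATConfig(base_config, step="convert"))
```

After convert, each linear's `weight.scale` should match the scale learned during training, which is what the parametrized test above verifies for all three base configs.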
11 changes: 10 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -245,6 +245,9 @@ def from_hp_to_intx(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
):
"""Convert a high precision tensor to an integer affine quantized tensor."""
original_shape = input_float.shape
@@ -288,7 +291,13 @@ def from_hp_to_intx(
)
data = data.to(target_dtype)
else:
if zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
if (custom_scale is None) != (custom_zero_point is None):
raise ValueError(
"`custom_scale` and `custom_zero_point` must be both defined or both None"
)
if custom_scale is not None and custom_zero_point is not None:
scale, zero_point = custom_scale, custom_zero_point
elif zero_point_domain == ZeroPointDomain.FLOAT and not preserve_zero:
scale, zero_point = _choose_qparams_affine_tinygemm(
input_float,
mapping_type,
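Note that the parentheses in the check above matter: without them Python chains `is` and `!=` into a comparison that never raises. An equivalent standalone sketch of the both-or-neither contract (the helper name is illustrative, not from the PR):

```python
from typing import Optional

import torch


def _validate_custom_qparams(
    custom_scale: Optional[torch.Tensor],
    custom_zero_point: Optional[torch.Tensor],
) -> None:
    # XOR on "is None": supplying exactly one of the two is an error.
    if (custom_scale is None) != (custom_zero_point is None):
        raise ValueError(
            "`custom_scale` and `custom_zero_point` must be both defined or both None"
        )
```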
32 changes: 26 additions & 6 deletions torchao/quantization/qat/api.py
@@ -21,6 +21,7 @@
from .fake_quantize_config import (
FakeQuantizeConfig, # noqa: F401, for BC
FakeQuantizeConfigBase,
IntxFakeQuantizeConfig,
_infer_fake_quantize_configs,
)
from .linear import FakeQuantizedLinear
@@ -220,22 +221,41 @@ def _qat_config_transform(
)
else:
# Convert step
assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
assert config.activation_config is None, "unexpected `activation_config`"
assert config.weight_config is None, "unexpected `weight_config`"

# Ignore unrelated modules
if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
return module

# Optionally pass custom scales and zero points to base config handler
# This is only for range learning and only applies to weights
kwargs = {}
weight_config = module.weight_fake_quantizer.config
if (
isinstance(weight_config, IntxFakeQuantizeConfig)
and weight_config.range_learning
):
kwargs["custom_scale"] = module.weight_fake_quantizer.scale
kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point

# Swap FakeQuantizedLinear -> nn.Linear
# Swap FakeQuantizedEmbedding -> nn.Embedding
# Then apply the base config's transform function to quantize the model
# If there is no base config, then simply perform the module swap
assert step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % step
assert config.activation_config is None, "unexpected `activation_config`"
assert config.weight_config is None, "unexpected `weight_config`"
if isinstance(module, FakeQuantizedLinear):
module = module.to_linear()
elif isinstance(module, FakeQuantizedEmbedding):
module = module.to_embedding()
else:
# Unrelated module, ignore
return module
raise ValueError(
f"Encountered unexpected module {module}, should never happen"
)
if base_config is not None:
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](
module, base_config, **kwargs
)
else:
return module

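In isolation, the hand-off implemented above amounts to the following sketch (the helper name and `getattr` defaults are illustrative, not the PR's code): the learned qparams are read off a range-learning weight fake quantizer and forwarded to the base config handler as keyword arguments, so nothing changes for configs that never receive them.

```python
from typing import Any, Dict


def collect_convert_kwargs(module) -> Dict[str, Any]:
    # Only range-learning weight fake quantizers carry learned qparams that
    # should override whatever the base config handler would derive itself.
    kwargs: Dict[str, Any] = {}
    fake_quantizer = getattr(module, "weight_fake_quantizer", None)
    if fake_quantizer is not None and getattr(fake_quantizer.config, "range_learning", False):
        kwargs["custom_scale"] = fake_quantizer.scale
        kwargs["custom_zero_point"] = fake_quantizer.zero_point
    return kwargs


# Usage mirrors the convert step: handler(module, base_config, **collect_convert_kwargs(module))
```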
60 changes: 53 additions & 7 deletions torchao/quantization/quant_api.py
Expand Up @@ -647,7 +647,11 @@ def __post_init__(self):

@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
def _int8_dynamic_activation_int4_weight_transform(
module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig
module: torch.nn.Module,
config: Int8DynamicActivationInt4WeightConfig,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
):
group_size = config.group_size
layout = config.layout
Expand Down Expand Up @@ -700,6 +704,8 @@ def _int8_dynamic_activation_int4_weight_transform(
quant_min=0,
quant_max=15,
_layout=layout,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
else:
weight = to_affine_quantized_intx(
Expand All @@ -710,6 +716,8 @@ def _int8_dynamic_activation_int4_weight_transform(
quant_min,
quant_max,
_layout=layout,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
weight = to_linear_activation_quantized(weight, input_quant_func)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
Expand Down Expand Up @@ -809,7 +817,14 @@ def __post_init__(self):
)


def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
def _int8_dynamic_activation_intx_weight_quantize_tensor(
weight,
bias,
config,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
):
weight_dtype = config.weight_dtype
weight_granularity = config.weight_granularity
weight_mapping_type = config.weight_mapping_type
Expand Down Expand Up @@ -847,12 +862,16 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8
or intx_packing_format in opaque_formats
), f"Unsupported packing format: {intx_packing_format}"
if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
custom_zero_point = custom_zero_point.to(torch.int8)
new_weight = IntxUnpackedToInt8Tensor.from_hp(
weight,
block_size,
weight_dtype,
mapping_type=weight_mapping_type,
activation_quantization="int8_asym_per_token",
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
_adjust_scale_dtype_in_intx_unpacked_tensor(
Expand Down Expand Up @@ -939,10 +958,18 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):

@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
def _int8_dynamic_activation_intx_weight_transform(
module: torch.nn.Module, config: Int8DynamicActivationIntxWeightConfig
module: torch.nn.Module,
config: Int8DynamicActivationIntxWeightConfig,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
) -> torch.nn.Module:
new_weight, new_bias = _int8_dynamic_activation_intx_weight_quantize_tensor(
module.weight, module.bias, config
module.weight,
module.bias,
config,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
if new_bias is None:
Expand Down Expand Up @@ -2177,7 +2204,13 @@ def __post_init__(self):
)


def _intx_weight_only_quantize_tensor(weight, config):
def _intx_weight_only_quantize_tensor(
weight,
config,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
):
weight_dtype = config.weight_dtype
granularity = config.granularity
mapping_type = config.mapping_type
Expand All @@ -2202,11 +2235,15 @@ def _intx_weight_only_quantize_tensor(weight, config):

if config.version == 2:
if config.intx_packing_format == IntxPackingFormat.UNPACKED_TO_INT8:
if custom_zero_point is not None and custom_zero_point.dtype == torch.int32:
custom_zero_point = custom_zero_point.to(torch.int8)
new_weight = IntxUnpackedToInt8Tensor.from_hp(
weight,
block_size,
weight_dtype,
mapping_type=mapping_type,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
if scale_dtype is not None and scale_dtype != weight.dtype:
_adjust_scale_dtype_in_intx_unpacked_tensor(
Contributor: Presumably you don't want this to run if you passed custom scales?

Contributor Author: I looked into this a bit and I think it's actually fine to run this. Basically we use the custom scale first and then cast it to the configured scale dtype (otherwise `custom_scale` would have to leak into this file, which I hope to avoid).

Contributor: I think you're right.
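A short sketch of the ordering described in this thread (names are illustrative): the learned scale is used for quantization as-is, and the cast to the configured scale dtype happens afterwards on the stored tensor, so no qparams are re-derived.

```python
from typing import Optional

import torch


def finalize_scale(learned_scale: torch.Tensor, scale_dtype: Optional[torch.dtype]) -> torch.Tensor:
    # Post-processing only: quantization already happened with `learned_scale`.
    if scale_dtype is not None and scale_dtype != learned_scale.dtype:
        return learned_scale.to(scale_dtype)
    return learned_scale
```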

Expand Down Expand Up @@ -2241,13 +2278,22 @@ def _intx_weight_only_quantize_tensor(weight, config):

@register_quantize_module_handler(IntxWeightOnlyConfig)
def _intx_weight_only_transform(
module: torch.nn.Module, config: IntxWeightOnlyConfig
module: torch.nn.Module,
config: IntxWeightOnlyConfig,
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying intx weight only quant requires module to have weight attribute"
+ " but {module} does not have one"
)
new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
new_weight = _intx_weight_only_quantize_tensor(
module.weight,
config,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)

if isinstance(module, nn.Linear):
@@ -177,20 +177,29 @@ def from_hp(
activation_quantization: Optional[
IntxUnpackedToInt8TensorActivationQuantization
] = None,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
):
"""
Create an IntxUnpackedToInt8Tensor from a high-precision tensor
"""
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
scale, zero_point = choose_qparams_affine(
hp_tensor,
mapping_type,
block_size,
target_dtype=torch.int8,
quant_min=qmin,
quant_max=qmax,
zero_point_dtype=torch.int8,
)
if custom_scale is not None and custom_zero_point is not None:
scale, zero_point = custom_scale, custom_zero_point
elif custom_scale is None and custom_zero_point is None:
scale, zero_point = choose_qparams_affine(
hp_tensor,
mapping_type,
block_size,
target_dtype=torch.int8,
quant_min=qmin,
quant_max=qmax,
zero_point_dtype=torch.int8,
)
else:
raise ValueError(
"`custom_scale` and `custom_zero_point` must be both defined or both None"
)
qdata = quantize_affine(
hp_tensor,
block_size,
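A hypothetical call pattern for the new keyword-only arguments, following the positional order used by the `quant_api.py` call sites above. The shapes, block size, and the assumption that the remaining arguments have usable defaults are illustrative; `custom_scale` and `custom_zero_point` must be supplied together.

```python
import torch

# Assumes IntxUnpackedToInt8Tensor is imported from its torchao module, as in quant_api.py above.
weight = torch.randn(64, 32)

# Learned qparams, e.g. taken from a range-learning fake quantizer:
# one (scale, zero_point) pair per group of 32 along the last dimension.
learned_scale = torch.full((64, 1), 0.05)
learned_zero_point = torch.zeros(64, 1, dtype=torch.int8)

quantized = IntxUnpackedToInt8Tensor.from_hp(
    weight,
    (1, 32),      # block_size: one group of 32 along the last dimension
    torch.int4,   # target dtype
    custom_scale=learned_scale,
    custom_zero_point=learned_zero_point,
)
```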