Commit d8721d7
add static quant
1 parent 4e2f09c commit d8721d7

File tree

4 files changed: +106 −32 lines

- test/quantization/quantize_/workflows/int8/test_int8_tensor.py
- torchao/quantization/__init__.py
- torchao/quantization/quant_api.py
- torchao/quantization/quantize_/workflows/int8/int8_tensor.py


test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 3 additions & 11 deletions
@@ -31,13 +31,13 @@
         version=2, granularity=PerTensor(), act_mapping_type=MappingType.ASYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
+        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
+        version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
+        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
     ),
 ]

@@ -77,14 +77,6 @@ def test_creation_and_attributes(self, config):
         elif isinstance(config.granularity, PerTensor):
             self.assertEqual(w.scale.shape, (1, 1))
 
-        if config.act_mapping_type == MappingType.SYMMETRIC:
-            self.assertEqual(w.zero_point, None)
-        elif config.act_mapping_type == MappingType.ASYMMETRIC:
-            if isinstance(config.granularity, PerRow):
-                self.assertEqual(w.zero_point.shape, (w.shape[0], 1))
-            elif isinstance(config.granularity, PerTensor):
-                self.assertEqual(w.zero_point.shape, (1, 1))
-
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("config", INT8_TEST_CONFIGS)

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -150,6 +151,7 @@
     "Int8DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
+    "Int8StaticActivationInt8WeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",

torchao/quantization/quant_api.py

Lines changed: 65 additions & 6 deletions
@@ -88,6 +88,7 @@
     IntxPackingFormat,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
+    QuantizeTensorToInt8Kwargs,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
         )
         quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
     else:
-        from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-            QuantizeTensorToInt8Kwargs,
-        )
-
         assert config.granularity in {PerRow(), PerTensor()}, (
             "Only PerRow and PerTensor are supported"
         )
@@ -1608,7 +1605,7 @@
             granularity=config.granularity,
             act_quant_kwargs=QuantizeTensorToInt8Kwargs(
                 granularity=act_granularity,
-                act_mapping_type=config.act_mapping_type,
+                mapping_type=config.act_mapping_type,
             ),
         )

@@ -1617,7 +1614,10 @@
 
 @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
 def _int8_dynamic_activation_int8_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationInt8WeightConfig,
+    *,
+    parameter_name="weight",
 ) -> torch.nn.Module:
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,6 +1634,65 @@ def _int8_dynamic_activation_int8_weight_transform(
     return module
 
 
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static quantization to activations and int8
+    quantization to weights, using precomputed activation quantization
+    parameters (e.g. obtained from a calibration pass).
+
+    Args:
+        scale (torch.Tensor): The precomputed scale tensor for activation quantization.
+        zero_point (Optional[torch.Tensor]): The precomputed zero point for activation quantization; None for symmetric quantization.
+        act_mapping_type (Optional[MappingType]): The mapping type for activation quantization. Default is MappingType.SYMMETRIC.
+        granularity (Optional[Union[Granularity, List[Granularity]]]): The quantization granularity. Default is PerRow().
+        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        version (int): the version of the config. Default is 1.
+    """
+
+    scale: torch.Tensor
+    zero_point: Optional[torch.Tensor] = None
+    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
+    granularity: Optional[Union[Granularity, List[Granularity]]] = PerRow()
+    set_inductor_config: bool = True
+    version: int = 1
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8StaticActivationInt8WeightConfig"
+        )
+
+
+@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
+def _int8_static_activation_int8_weight_transform(
+    module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig
+):
+    assert config.granularity in {PerRow(), PerTensor()}, (
+        "Only PerRow and PerTensor are supported"
+    )
+
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    activation_granularity, weight_granularity = _normalize_granularity(
+        config.granularity
+    )
+    weight = module.weight
+
+    # TODO: Symmetric/Asymmetric choice for weight quantization
+    # https://github.com/pytorch/ao/pull/3241#discussion_r2551515539
+    quantized_weight = Int8Tensor.from_hp(
+        weight,
+        granularity=weight_granularity,
+        act_quant_kwargs=QuantizeTensorToInt8Kwargs(
+            granularity=activation_granularity,
+            mapping_type=config.act_mapping_type,
+            scale=config.scale,
+            zero_point=config.zero_point,
+        ),
+    )
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 def int8_dynamic_activation_int8_semi_sparse_weight():
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
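
For orientation, here is a minimal usage sketch of the new config. The model and the activation scale are placeholders; in practice the scale would come from a calibration pass over representative data:

```python
# Hypothetical usage sketch of the static-quantization config added above.
import torch

from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerTensor

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)

# Placeholder calibration result: a (1, 1) per-tensor activation scale.
act_scale = torch.tensor([[0.02]], dtype=torch.bfloat16)

quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(scale=act_scale, granularity=PerTensor()),
)
```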

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 36 additions & 15 deletions
@@ -37,7 +37,9 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
     """
 
     granularity: Granularity = PerRow()
-    act_mapping_type: MappingType = MappingType.SYMMETRIC
+    mapping_type: MappingType = MappingType.SYMMETRIC
+    scale: Optional[torch.Tensor] = None
+    zero_point: Optional[torch.Tensor] = None
 
 
 class Int8Tensor(TorchAOBaseTensor):
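
The rename from `act_mapping_type` to `mapping_type`, plus the new optional `scale`/`zero_point` fields, lets a single kwargs object describe both activation-quantization modes. A sketch with illustrative values:

```python
# Illustrative only: the same kwargs dataclass now covers both modes.
import torch

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

# Dynamic: no qparams stored; they are chosen from each activation at runtime.
dynamic_kwargs = QuantizeTensorToInt8Kwargs(
    granularity=PerRow(), mapping_type=MappingType.SYMMETRIC
)

# Static: precomputed qparams travel with the weight and are reused each forward.
static_kwargs = QuantizeTensorToInt8Kwargs(
    granularity=PerTensor(),
    mapping_type=MappingType.SYMMETRIC,
    scale=torch.tensor([[0.02]]),  # placeholder calibration result
)
```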
@@ -58,6 +60,7 @@ class Int8Tensor(TorchAOBaseTensor):
     tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = []
     optional_tensor_attribute_names = [
+        "zero_point",
         "block_size",
         "act_quant_kwargs",
         "dtype",
@@ -67,6 +70,7 @@ def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor] = None,
         block_size: Optional[List[int]] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         dtype: Optional[torch.dtype] = None,
@@ -82,13 +86,15 @@ def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor] = None,
         block_size: Optional[List[int]] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
+        self.zero_point = zero_point
         self.block_size = block_size
         self.act_quant_kwargs = act_quant_kwargs

@@ -98,6 +104,7 @@ def __repr__(self):
             f"act_quant_kwargs={self.act_quant_kwargs}, "
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
+            f"zero_point={self.zero_point}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -110,23 +117,30 @@ def from_hp(
         hp_tensor: torch.Tensor,
         granularity: Granularity = PerRow(),
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        scale: Optional[torch.Tensor] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        mapping_type: MappingType = MappingType.SYMMETRIC,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        scale, zero_point = choose_qparams_affine(
-            input=hp_tensor,
-            mapping_type=MappingType.SYMMETRIC,
-            block_size=block_size,
-            target_dtype=torch.int8,
-            quant_min=-128,
-            quant_max=127,
-            scale_dtype=hp_tensor.dtype,
-            zero_point_dtype=torch.int8,
-            keepdim=True,
-        )
-
+        # If scale and zero_point are not given, choose them dynamically.
+        if scale is None and zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                input=hp_tensor,
+                mapping_type=mapping_type,
+                block_size=block_size,
+                target_dtype=torch.int8,
+                quant_min=-128,
+                quant_max=127,
+                scale_dtype=hp_tensor.dtype,
+                zero_point_dtype=torch.int8,
+                keepdim=True,
+            )
+
+        # If they are given, they are used to quantize directly;
+        # this is how static quantization is supported.
         int_data = quantize_affine(
             hp_tensor,
             block_size=block_size,
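
Whichever way the qparams are obtained, the quantize step that follows is the standard affine mapping. A hand-worked check with illustrative values (not the library call itself):

```python
# Roughly what quantize_affine computes for int8:
#   q = clamp(round(x / scale) + zero_point, -128, 127)
import torch

x = torch.tensor([0.5, -1.0, 0.25])
scale = torch.tensor(0.01)
zero_point = torch.tensor(0)

q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)
# q -> tensor([  50, -100,   25], dtype=torch.int8)
```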
@@ -145,11 +159,14 @@ def from_hp(
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize int8 tensor to floating point"""
+        zero_point = self.zero_point
+        if zero_point is not None:
+            zero_point = zero_point.squeeze()
         return dequantize_affine(
             input=self.qdata,
             block_size=self.block_size,
             scale=self.scale.squeeze(),
-            zero_point=None,
+            zero_point=zero_point,
             input_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
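
Dequantization is the affine inverse, and threading the zero point through here is what makes the asymmetric case round-trip. Another hand check with illustrative values:

```python
# Roughly what dequantize_affine computes:
#   x_hat = (q - zero_point) * scale
import torch

q = torch.tensor([50, -100, 25], dtype=torch.int8)
scale = torch.tensor(0.01)
zero_point = torch.tensor(10)

x_hat = (q.to(torch.float32) - zero_point) * scale
# x_hat -> tensor([ 0.4000, -1.1000,  0.1500])
```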
@@ -179,7 +196,11 @@ def _(func, types, args, kwargs):
 
     if weight_tensor.act_quant_kwargs is not None:
         activation_tensor = Int8Tensor.from_hp(
-            activation_tensor, weight_tensor.act_quant_kwargs.granularity
+            activation_tensor,
+            granularity=weight_tensor.act_quant_kwargs.granularity,
+            mapping_type=weight_tensor.act_quant_kwargs.mapping_type,
+            scale=weight_tensor.act_quant_kwargs.scale,
+            zero_point=weight_tensor.act_quant_kwargs.zero_point,
         )
         # Dynamic activation quantization path
