
Commit 1591603

Support Int4OpaqueTensor for HQQ (#3028)
* Support Int4OpaqueTensor for HQQ

  Make Int4OpaqueTensor support HQQ.

  Signed-off-by: Cui, Lily <lily.cui@intel.com>

* Format codes

  Signed-off-by: Cui, Lily <lily.cui@intel.com>

---------

Signed-off-by: Cui, Lily <lily.cui@intel.com>
1 parent a951643 commit 1591603

File tree

3 files changed: +74, -32 lines changed


test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py

Lines changed: 11 additions & 7 deletions
@@ -26,10 +26,11 @@
 )
 
 
-def get_config(group_size):
+def get_config(group_size, use_hqq):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="opaque",
+        int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
     )
 
 
@@ -45,13 +46,14 @@ class TestInt4OpaqueTensor(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
+    @parametrize("use_hqq", [True, False])
+    def test_linear(self, sizes, dtype, group_size, use_hqq):
         device = "cpu"
         M, N, K = sizes
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(group_size))
+        quantize_(linear, get_config(group_size, use_hqq))
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
@@ -60,9 +62,10 @@ def test_linear(self, sizes, dtype, group_size):
         self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
 
     @parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
-    def test_module_path(self, dtype):
+    @parametrize("use_hqq", [True, False])
+    def test_module_path(self, dtype, use_hqq):
         linear = torch.nn.Linear(128, 256, dtype=dtype)
-        quantize_(linear, get_config(group_size=128))
+        quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4OpaqueTensor'>",
@@ -77,12 +80,13 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4OpaqueTensor'>",
         )
 
-    def test_activation_prescaling(self):
+    @parametrize("use_hqq", [True, False])
+    def test_activation_prescaling(self, use_hqq):
         dtype = torch.bfloat16
         input = torch.randn(1, 128, dtype=dtype)
         linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
         original_output = linear(input)
-        quantize_(linear, get_config(group_size=128))
+        quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"

torchao/quantization/quant_api.py

Lines changed: 9 additions & 4 deletions
@@ -1082,15 +1082,15 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2
-        `packing_format`: the packing format for int4 tensor, used in version 2 only
+        `int4_packing_format`: the packing format for int4 tensor, used in version 2 only
         `int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4,
          currently support TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
         `use_hqq`: whether to use hqq or default quantization mode, default is False, used in version 1 only
         `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. used in both version 1 and 2
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
-        `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 1, see note for more details
+        `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 2, see note for more details
 
     Note:
         Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2
@@ -1147,8 +1147,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
         block_size = list(block_size)
 
     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
-            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D curretnly"
+        assert int4_packing_format in [
+            Int4PackingFormat.TILE_PACKED_TO_4D,
+            Int4PackingFormat.OPAQUE,
+        ], (
+            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
+            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
         )
 
     if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1180,6 +1184,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
         new_weight = Int4OpaqueTensor.from_hp(
             weight,
             block_size,
+            int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
         )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
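
With the forwarding added above, the algorithm keyword reaches Int4OpaqueTensor.from_hp (third file below). A minimal direct-construction sketch follows; it assumes the conventional [1, group_size] block_size layout and that the enum's module path matches the relative import shown in that file, neither of which is stated in this diff.

import torch
from torchao.quantization import Int4OpaqueTensor
from torchao.quantization.quantize_.workflows.int4.int4_choose_qparams_algorithm import (
    Int4ChooseQParamsAlgorithm,
)

w = torch.randn(256, 128, dtype=torch.bfloat16)  # 2D CPU weight, as from_hp requires
qw = Int4OpaqueTensor.from_hp(
    w,
    block_size=[1, 128],  # assumed per-group layout with group_size=128
    int4_choose_qparams_algorithm=Int4ChooseQParamsAlgorithm.HQQ,
)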

torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py

Lines changed: 54 additions & 21 deletions
@@ -5,19 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import math
 from typing import List, Optional
 
 import torch
 
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_qparams_affine_tinygemm,
+    _choose_qparams_and_quantize_affine_hqq,
     _quantize_affine_tinygemm,
 )
+from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import (
     TorchAOBaseTensor,
 )
 
+from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
+
 __all__ = [
     "Int4OpaqueTensor",
 ]
@@ -95,6 +100,7 @@ def from_hp(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM,
     ):
         assert w.ndim == 2 and w.device.type == "cpu", (
             f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
@@ -111,26 +117,54 @@ def from_hp(
         eps = 1e-6
         scale_dtype = None
         zero_point_dtype = w.dtype
-        scale, zero_point = _choose_qparams_affine_tinygemm(
-            w,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype,
-            zero_point_dtype,
-        )
-        int_data = _quantize_affine_tinygemm(
-            w,
-            block_size,
-            scale,
-            zero_point,
-            target_dtype,
-            quant_min,
-            quant_max,
-        )
+
+        # we support two paths for constructing a Int4OpaqueTensor
+        # 1. use [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute
+        #    scale and zero_point, then convert to the format that's compatible with tinygemm kernels
+        # 2. don't use hqq, use default tinygemm algorithm to compute scale and zero_point
+        #
+        # both approach should have the same performance since both are using CPU tinygemm kernel for gemm
+        # 1. typically will have higher accuracy compared to 2.
+        if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
+            nbits = int(math.log2(quant_max + 1))
+            axis = 1
+            group_size = block_size[-1]
+            int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
+                w,
+                nbits=nbits,
+                group_size=group_size,
+                axis=axis,
+                compute_dtype=zero_point_dtype,
+                device=w.device,
+            )
+            int_data = int_data.to(target_dtype)
+        else:
+            assert (
+                int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM
+            ), (
+                f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}"
+            )
+
+            scale, zero_point = _choose_qparams_affine_tinygemm(
+                w,
+                mapping_type,
+                block_size,
+                target_dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype,
+                zero_point_dtype,
+            )
+            int_data = _quantize_affine_tinygemm(
+                w,
+                block_size,
+                scale,
+                zero_point,
+                target_dtype,
+                quant_min,
+                quant_max,
+            )
         assert int_data.dtype == torch.int32, (
             "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
         )
@@ -141,7 +175,6 @@ def from_hp(
 
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
-        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
         return Int4OpaqueTensor(
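
The parameters derived in the new HQQ branch are simple arithmetic. Assuming the usual unsigned int4 range (quant_min=0, quant_max=15) used elsewhere in from_hp but not shown in this hunk, a quick check of what the branch computes:

import math

quant_max = 15          # assumed uint4 upper bound (defined earlier in from_hp)
block_size = [1, 128]   # per-group quantization along the last dimension

nbits = int(math.log2(quant_max + 1))  # 4 -> HQQ quantizes to 4 bits
group_size = block_size[-1]            # 128 -> HQQ group size
axis = 1                               # grouping runs along the input-feature axis
print(nbits, group_size, axis)         # 4 128 1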

0 commit comments