Commit e21e5bf
Improve QAT fp8-int4 numerics
**Summary:** This commit improves the prepare vs convert SQNR of fp8-int4 QAT from 12 to 22. This is achieved by mimicking the numerics of the target FBGEMM fp8-int4 kernel more closely. In particular, FBGEMM first quantizes the weights to fp8 and then uses the max abs values of the fp8 weights to compute the scale, which is significantly different from what torchao's quant primitives do.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_fbgemm_fp8_primitives
python test/quantization/test_qat.py -k test_fbgemm_int4_primitives
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
1 parent 4700fe8 commit e21e5bf
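For reference, the two-step numerics described in the summary can be sketched in plain torch ops: a rowwise fp8 quantize-dequantize of the weight, followed by a per-group int4 scale taken from the max abs of the fp8 weights. This is only an illustration of the targeted behavior; the constants (`1e-12`, `1e-6`, the `/ 8` for the symmetric int4 range) mirror the reference added in `fake_quantizer.py` below, and the helper name is made up for this sketch, not part of the diff.

```python
import torch

def fbgemm_style_fp8_int4_fake_quant(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    # Step 1: quantize the weight to fp8 per row (scale = row max abs / fp8 max), then dequantize
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    row_max = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    fp8_scale = row_max / fp8_max
    w_fp8 = (w / fp8_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    w_fp8 = w_fp8.to(w.dtype) * fp8_scale

    # Step 2: compute the int4 scale from the max abs of the fp8 weights per group,
    # then fake quantize to the symmetric int4 range [-8, 7]
    w_grouped = w_fp8.view(w_fp8.shape[0], -1, group_size)
    max_abs = w_grouped.abs().amax(dim=-1, keepdim=True)
    scale = (max_abs / 8).clamp(min=1e-6)
    q = (w_grouped / scale).round().clamp(-8, 7)
    return (q * scale).view_as(w)

w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fq = fbgemm_style_fp8_int4_fake_quant(w)
```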

File tree: 3 files changed (+228, -11 lines)


test/quantization/test_qat.py

Lines changed: 139 additions & 4 deletions
```diff
@@ -49,6 +49,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
+    Int4WeightFBGEMMFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=22,
             target_convert_sqnr=float("inf"),
         )
 
@@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int):
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_quantize_api_int8_int4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+        """
+        self._test_quantize_api_against_ptq(
+            Int8DynamicActivationInt4WeightConfig(group_size=32),
+            target_prepare_sqnr=30,
+            target_convert_sqnr=float("inf"),
+        )
+
     def test_infer_fp8_int4_config(self):
         """
         Test that fake quantize configs are correctly inferred from
@@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, torch.float8_e4m3fn)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
-        self.assertEqual(weight_config.dtype, torch.int4)
+        self.assertIsInstance(weight_config, Int4WeightFBGEMMFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
-        self.assertTrue(weight_config.is_symmetric)
+        self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn)
 
     def test_infer_int4_weight_only_config(self):
         """
@@ -2033,6 +2046,128 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreater(sqnr, 24)
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_fp8_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+          (2) Our reference QAT version in `Float8FakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+        )
+
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_fp8_row`
+        (q1, scale1) = quantize_fp8_row(x1)
+
+        # (2) Our reference implementation for QAT without the dequantize
+        scale2 = _choose_scale_float8(
+            x2,
+            (1, x2.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        q2 = _quantize_affine_float8(
+            x2, scale2, torch.float8_e4m3fn, cast_to_float8_dtype=False
+        )
+        sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32))
+        scale_sqnr = compute_error(
+            scale1.to(torch.float32).flatten(),
+            scale2.to(torch.float32).flatten(),
+        )
+        self.assertGreater(sqnr, 30)
+        self.assertGreater(scale_sqnr, 50)
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(
+        not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+    )
+    def test_fbgemm_int4_primitives(self):
+        """
+        Compare numerics between:
+          (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
+          (2) Our reference QAT version in `Int4WeightFBGEMMFakeQuantizer`
+        """
+        from fbgemm_gpu.experimental.gen_ai.quantize import (
+            int4_row_quantize,
+            pack_int4,
+            quantize_fp8_row,
+            quantize_int4_preshuffle,
+        )
+
+        from torchao.quantization.quant_primitives import (
+            _choose_scale_float8,
+            _quantize_affine_float8,
+            _quantize_affine_no_dtype_cast,
+        )
+
+        group_size = 128
+        x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda()
+        x2 = copy.deepcopy(x1)
+        x3 = copy.deepcopy(x1)
+
+        # (1) Just call `quantize_int4_preshuffle`
+        (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8")
+
+        # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling
+        (q2, _) = quantize_fp8_row(x2)
+        (q2, scale2) = int4_row_quantize(q2, group_size)
+
+        # (3) Reference implementation for QAT without the dequantize
+        fp8_scale = _choose_scale_float8(
+            x3,
+            (1, x3.shape[-1]),
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn)
+        x3_fp8 = x3_fp8.to(torch.float32)
+        x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size)
+        max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / 8, min=1e-6)
+        zero_point = torch.zeros_like(scale)
+        q3 = _quantize_affine_no_dtype_cast(
+            x3_fp8,
+            (1, group_size),
+            scale,
+            zero_point,
+            quant_min=-8,
+            quant_max=7,
+        )
+        scale3 = scale
+
+        def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+            t = pack_int4(t.to(torch.int8))
+            return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0]
+
+        # First, sanity check that shuffle_and_pack(q2) == q1
+        torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0)
+
+        # Now check q2 vs q3 with and without shuffle
+        sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32))
+        sqnr_q2_q3_preshuffle = compute_error(
+            shuffle_and_pack(q2, scale2).to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q2_q3, 32)
+        self.assertGreater(sqnr_q2_q3_preshuffle, 32)
+
+        # Now check shuffle_and_pack(q3) vs q1
+        sqnr_q1_q3_preshuffle = compute_error(
+            q1.to(torch.float32),
+            shuffle_and_pack(q3, scale3).to(torch.float32),
+        )
+        self.assertGreater(sqnr_q1_q3_preshuffle, 32)
+
 
 instantiate_parametrized_tests(TestQAT)
 
```

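For context, here is a minimal sketch of the end-to-end QAT flow that `_test_quantize_api_against_ptq` exercises, following the pattern spelled out in the test docstrings above. The toy model and the exact import paths are assumptions for illustration; only `quantize_`, `QATConfig`, and `Float8DynamicActivationInt4WeightConfig` are taken from the tests themselves.

```python
import torch
from torchao.quantization import Float8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16)).cuda()
base_config = Float8DynamicActivationInt4WeightConfig()

# Prepare: insert fake quantization that mimics the fp8-int4 kernel numerics
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...

# Convert: swap the fake quantization for real fp8-int4 quantized weights
quantize_(model, QATConfig(base_config, step="convert"))
```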
torchao/quantization/qat/fake_quantize_config.py

Lines changed: 20 additions & 3 deletions
```diff
@@ -77,6 +77,24 @@ def __post_init__(self):
             )
 
 
+@dataclass
+class Int4WeightFBGEMMFakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for int4 weight fake quantization that targets the numerics in the following FBGEMM kernel:
+        torch.ops.fbgemm.f8i4bf16_shuffled
+
+    Currently this only supports float8 input activations. In the future, we may extend this
+    to support bfloat16 as well.
+    """
+
+    group_size: int = 128
+    activation_dtype: torch.dtype = e4m3_dtype
+
+    def __post_init__(self):
+        if self.activation_dtype != e4m3_dtype:
+            raise ValueError(f"Only {e4m3_dtype} activation is supported currently")
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -404,10 +422,9 @@ def _infer_fake_quantize_configs(
             dtype=torch.float8_e4m3fn,
             granularity=PerRow(),
         )
-        weight_config = IntxFakeQuantizeConfig(
-            dtype=torch.int4,
+        weight_config = Int4WeightFBGEMMFakeQuantizeConfig(
            group_size=128,
-            is_symmetric=True,
+            activation_dtype=e4m3_dtype,
         )
     elif isinstance(base_config, NVFP4InferenceConfig):
         # Note: today the PTQ config does not allow the user to specify
```

torchao/quantization/qat/fake_quantizer.py

Lines changed: 69 additions & 4 deletions
```diff
@@ -4,13 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
+    PerRow,
     PerToken,
 )
 from torchao.quantization.observer import get_block_size
@@ -20,6 +21,7 @@
     MappingType,
     _choose_scale_float8,
     _dequantize_affine_float8,
+    _fake_quantize_affine,
     _quantize_affine_float8,
     _Round,
     choose_qparams_affine,
@@ -33,6 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
+    Int4WeightFBGEMMFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -65,6 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
 
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
+        elif isinstance(config, Int4WeightFBGEMMFakeQuantizeConfig):
+            return Int4WeightFBGEMMFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -93,13 +98,73 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             hp_value_lb=self.config.hp_value_lb,
             hp_value_ub=self.config.hp_value_ub,
         )
-        q = _quantize_affine_float8(
-            x, scale, self.config.dtype, cast_to_float8_dtype=False
-        )
+        q = _quantize_affine_float8(x, scale, self.config.dtype)
         dq = _dequantize_affine_float8(q, scale, original_dtype)
         return dq
 
 
+class Int4WeightFBGEMMFakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying int4 fake quantization to a weight tensor,
+    targeting the following FBGEMM kernel:
+        torch.ops.fbgemm.f8i4bf16_shuffled
+    """
+
+    def __init__(self, config: Int4WeightFBGEMMFakeQuantizeConfig):
+        super().__init__()
+        self.config = config
+        torch._C._log_api_usage_once(
+            "torchao.quantization.qat.Int4WeightFBGEMMFakeQuantizer"
+        )
+
+    def forward(self, w: torch.Tensor) -> torch.Tensor:
+        return self._forward(w)[0]
+
+    def _forward(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply int4 fake quantization to the weight tensor, using the following as a reference:
+        https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112
+
+        Currently, we expect the activations to always be rowwise float8.
+
+        Returns a 2-tuple of (fake quantized weight, per group scale).
+        """
+        assert w.dim() == 2
+        assert self.config.activation_dtype == torch.float8_e4m3fn
+
+        # First quantize weights to fp8 per row
+        # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row
+        per_row_block_size = get_block_size(w.shape, PerRow())
+        fp8_scale = _choose_scale_float8(
+            w,
+            per_row_block_size,
+            torch.float8_e4m3fn,
+            hp_value_lb=1e-12,
+        )
+        w_fp8 = _quantize_affine_float8(w, fp8_scale, torch.float8_e4m3fn)
+        w_fp8 = _dequantize_affine_float8(w_fp8, fp8_scale, w.dtype)
+
+        # Now quantize to int4 per group
+        # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize
+        eps = 1e-6
+        fbgemm_scale_quant_max = 8
+        w_fp8_grouped = w_fp8.view(w_fp8.shape[0], -1, self.config.group_size)
+        max_abs = torch.amax(torch.abs(w_fp8_grouped), dim=-1, keepdim=False)
+        scale = torch.clamp(max_abs / fbgemm_scale_quant_max, min=eps)
+        zero_point = torch.zeros_like(scale)
+        per_group_block_size = (1, self.config.group_size)
+        fq = _fake_quantize_affine(
+            w_fp8,
+            per_group_block_size,
+            scale,
+            zero_point,
+            quant_dtype=torch.int8,
+            quant_min=-8,
+            quant_max=7,
+        )
+        return (fq.to(w.dtype), scale)
+
+
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
```

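Taken together with the config added in `fake_quantize_config.py`, the new fake quantizer can also be driven directly through `FakeQuantizerBase.from_config`, as wired up in the dispatch above. A small sketch (import paths follow the file locations in this diff; the weight shape is an arbitrary example that satisfies the assertions in `_forward`):

```python
import torch
from torchao.quantization.qat.fake_quantize_config import Int4WeightFBGEMMFakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizerBase

config = Int4WeightFBGEMMFakeQuantizeConfig(group_size=128, activation_dtype=torch.float8_e4m3fn)
fake_quantizer = FakeQuantizerBase.from_config(config)  # returns an Int4WeightFBGEMMFakeQuantizer

# 2D weight whose last dimension is divisible by group_size
w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fq = fake_quantizer(w)  # fake-quantized weight, same shape and dtype as w
```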