Commit 8647095
Improve QAT nvfp4 numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and from 12 to inf without. This is achieved by mimicking the PTQ flow more closely, in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original dtype, e.g. bf16, which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.

ghstack-source-id: 04f6bce
Pull Request resolved: #3050
1 parent e1d89e7 commit 8647095
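To make item 1 of the summary concrete, below is a minimal sketch of the fp4 round trip that the new `fake_quantize` flag enables. The function names and the flag come from the diffs below; the example input values, the prints, and the exact-match expectation are illustrative assumptions, not part of the commit.

```python
import torch

from torchao.prototype.mx_formats.kernels import (
    f4_unpacked_to_f32,
    f32_to_f4_unpacked,
)

# Illustrative values already scaled into the fp4_e2m1 range (not from the commit).
x = torch.tensor([0.0, 0.4, 0.6, 1.2, 2.9, 5.5, -3.1], dtype=torch.float32)

# PTQ-style path: encode to uint8 codes, then decode back to fp32.
codes_uint8 = f32_to_f4_unpacked(x)
dq_ptq = f4_unpacked_to_f32(codes_uint8)

# QAT fake-quantize path: same encoding, but the codes stay in torch.int32
# and the values are never cast back to the original dtype.
codes_int32 = f32_to_f4_unpacked(x, fake_quantize=True)
dq_qat = f4_unpacked_to_f32(codes_int32, fake_quantize=True)

# The two paths should agree, which is what drives the prepare vs convert
# SQNR to inf when no per-tensor scale is used.
print(codes_uint8.dtype, codes_int32.dtype)  # torch.uint8 torch.int32
print(torch.equal(dq_ptq, dq_qat))           # expected: True
```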

5 files changed (+40, -18 lines)

test/quantization/test_qat.py (7 additions & 3 deletions)

```diff
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)

         # compare convert
@@ -2086,9 +2085,14 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
         """
         from torchao.prototype.mx_formats import NVFP4InferenceConfig

+        if use_per_tensor_scale:
+            target_prepare_sqnr = 36
+        else:
+            target_prepare_sqnr = float("inf")
+
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=target_prepare_sqnr,
             target_convert_sqnr=float("inf"),
         )

@@ -2116,7 +2120,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreater(sqnr, 10)

     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
```

torchao/prototype/custom_fp_utils.py (22 additions & 8 deletions)

```diff
@@ -24,7 +24,9 @@ def _n_ones(n: int) -> int:
 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


-def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _f32_to_floatx_unpacked(
+    x: Tensor, ebits: int, mbits: int, fake_quantize: bool = False
+) -> Tensor:
     """Convert FP32 numbers to sub-byte floating point numbers with the given
     number of exponent and mantissa bits.

@@ -105,7 +107,8 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     denormal_x = x + denorm_mask_float
     denormal_x = denormal_x.view(torch.int32)
     denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
+    if not fake_quantize:
+        denormal_x = denormal_x.to(torch.uint8)

     #
     # branch 3: stay in normal range, adjust the exponent and round
@@ -120,31 +123,41 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     normal_x += mant_odd
     # take the bits!
     normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
+    if not fake_quantize:
+        normal_x = normal_x.to(torch.uint8)

     #
     # combine the branches
     #
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    if fake_quantize:
+        x = torch.full_like(x, max_int, dtype=torch.int32)
+    else:
+        x = torch.full_like(x, max_int, dtype=torch.uint8)
     x = torch.where(denormal_mask, denormal_x, x)
     x = torch.where(normal_mask, normal_x, x)

     # add sign back
     sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
+    if not fake_quantize:
+        sign_lp = sign_lp.to(torch.uint8)
     # Right shift of a negative signed integer can fill the least significant
     # bits with either 1s or 0s, depending on the implementation. Since PyTorch
     # doesn't have an uint32 dtype, we mask out these bits to get just the
     # f4 sign bit
     sign_lp = sign_lp & sign_mask
     x = x | sign_lp

-    return x.to(torch.uint8)
+    if fake_quantize:
+        return x
+    else:
+        return x.to(torch.uint8)


 # TODO(future): check if LUT for everything is faster than bit shifting,
 # especially for fp4 (only 2^4=16 unique values).
-def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _floatx_unpacked_to_f32(
+    x: Tensor, ebits: int, mbits: int, fake_quantize: bool = False
+) -> Tensor:
     """Convert sub-byte floating point numbers with the given number of exponent
     and mantissa bits to FP32.

@@ -154,7 +167,8 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
     fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    assert x.dtype == torch.uint8
+    if not fake_quantize:
+        assert x.dtype == torch.uint8
     assert 1 + ebits + mbits <= 8

     sign_mask = 1 << (ebits + mbits)
```

torchao/prototype/mx_formats/kernels.py (4 additions & 4 deletions)

```diff
@@ -65,13 +65,13 @@ def get_bits(x: torch.Tensor) -> str:
 ZERO_POINT_FIVE_BITS_F32 = 0x3F000000


-def f32_to_f4_unpacked(x):
+def f32_to_f4_unpacked(x, fake_quantize: bool = False):
     """
     Input: torch.Tensor of dtype torch.float
     Output: torch.Tensor of dtype torch.uint8, with bits 0-3 empty and
         bits 4-7 in fp4_e2m1
     """
-    return _f32_to_floatx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1)
+    return _f32_to_floatx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1, fake_quantize)


 def f32_to_f6_e2m3_unpacked(x):
@@ -92,13 +92,13 @@ def f32_to_f6_e3m2_unpacked(x):
     return _f32_to_floatx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2)


-def f4_unpacked_to_f32(x: torch.Tensor):
+def f4_unpacked_to_f32(x: torch.Tensor, fake_quantize: bool = False):
     """
     Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7
         containing an fp4_e2m1 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    return _floatx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1)
+    return _floatx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1, fake_quantize)


 def f6_e2m3_unpacked_to_f32(x: torch.Tensor):
```

torchao/prototype/mx_formats/nvfp4_tensor.py (1 addition & 2 deletions)

```diff
@@ -798,7 +798,6 @@ def _nvfp4_quantize(
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"

-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -834,7 +833,7 @@
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
     if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
+        return _Float8Round.apply(out_scales), data_scaled
     else:
         data_lp = f32_to_f4_unpacked(data_scaled)
         # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
```
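The `_Float8Round.apply(out_scales)` call above implements item 3 of the summary (fake rounding the blockwise scales to float8) and, together with dropping `.to(orig_dtype)`, item 2. The helper itself is not shown in this diff; the snippet below is only a rough sketch of what such a function might look like, assuming e4m3 scales and straight-through gradients. It is a hypothetical stand-in, not the actual torchao implementation.

```python
import torch


class _Float8RoundSketch(torch.autograd.Function):
    """Hypothetical stand-in for the _Float8Round helper (not the real code)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Round the fp32 scales through float8_e4m3fn, then return them in fp32
        # so downstream math sees the same precision the packed PTQ path would.
        return x.to(torch.float8_e4m3fn).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Straight-through estimator: gradients pass through the rounding unchanged.
        return grad_output
```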

torchao/prototype/qat/nvfp4.py (6 additions & 1 deletion)

```diff
@@ -2,6 +2,10 @@

 import torch

+from torchao.prototype.mx_formats.kernels import (
+    f4_unpacked_to_f32,
+    f32_to_f4_unpacked,
+)
 from torchao.prototype.mx_formats.nvfp4_tensor import (
     _nvfp4_quantize,
     per_tensor_amax_to_scale,
@@ -56,13 +60,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             per_tensor_scale=per_tensor_scale,
             skip_dtype_cast_and_packing=True,
         )
+        q = f32_to_f4_unpacked(q, fake_quantize=True)
         if self.config.use_per_tensor_scale:
             scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
         assert scale.dtype == torch.float32

         # dequantize
         M, K = q.shape[0], q.shape[1]
+        q = f4_unpacked_to_f32(q, fake_quantize=True)
         q = q.view(M, K // block_size, block_size)
         scale = scale.view(M, K // block_size, 1)
         dq = q * scale
```
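For context on the dequantize step at the end of `forward`, here is a toy illustration of the blockwise broadcast. The sizes and random tensors are made-up assumptions (only `block_size = 16` comes from NVFP4), and the final reshape back to `(M, K)` is added purely for illustration.

```python
import torch

# Toy sizes; NVFP4 fixes block_size = 16, the rest is arbitrary.
M, K, block_size = 2, 32, 16

q = torch.randn(M, K)                   # stand-in for the fake-quantized fp32 values
scale = torch.rand(M, K // block_size)  # stand-in for the fp8-rounded blockwise scales

# Each scale is broadcast over its 16-element block, then the result is
# reshaped back to the original (M, K) layout.
dq = (
    q.view(M, K // block_size, block_size) * scale.view(M, K // block_size, 1)
).view(M, K)
assert dq.shape == (M, K)
```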
