Commit 28619d1

Improve QAT nvfp4 numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class that incorporates both the quantization and the mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:
- With this PR, the QAT nvfp4 quantized model recovered roughly 15% of the perplexity degradation introduced by quantization (10.2281 vs the quantized baseline's 10.3681, against a float baseline of 9.418)
- Without this PR, the QAT nvfp4 quantized model scored about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ | 9.418 |± | N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```

ghstack-source-id: 633bc65
Pull Request resolved: #3050
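
For context, the flow exercised by the end-to-end test is the standard QAT prepare → train → convert sequence. Below is a minimal sketch adapted from the `NVFP4FakeQuantizedLinear` docstring added in this diff; the toy model, shapes, optimizer, and loss are illustrative stand-ins, and running it assumes a recent CUDA GPU (the updated unit test skips below sm89).

```python
# Minimal sketch of the prepare -> train -> convert flow (illustrative, not from the PR).
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4InferenceConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().bfloat16()
base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)

quantize_(model, QATConfig(base_config, step="prepare"))
# linear layers are now `NVFP4FakeQuantizedLinear`, so training sees NVFP4 numerics

opt = torch.optim.SGD(model.parameters(), lr=2e-5)
for _ in range(10):  # stand-in for the real fine-tuning loop
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

quantize_(model, QATConfig(base_config, step="convert"))
# linear layers are back to `nn.Linear` with `NVFP4Tensor` weights
```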
1 parent c2cee3e commit 28619d1

7 files changed, +191 −87 lines changed

test/quantization/test_qat.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
 
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @parametrize("use_per_tensor_scale", [True, False])
     def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreaterEqual(sqnr, float("inf"))
 
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 11 additions & 36 deletions
```diff
@@ -768,37 +768,13 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
-    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
-
-
-class _Float8Round(torch.autograd.Function):
-    """
-    Cast a tensor to float8 and back to float32 with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return x.to(torch.float8_e4m3fn).to(torch.float32)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
-def _nvfp4_quantize(
-    data_hp: torch.Tensor,
-    block_size: int = 16,
-    per_tensor_scale: Optional[torch.Tensor] = None,
-    skip_dtype_cast_and_packing: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -810,8 +786,10 @@ def _nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
-        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -823,8 +801,8 @@ def _nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        )
-        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -833,11 +811,8 @@ def _nvfp4_quantize(
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
-    else:
-        data_lp = f32_to_f4_unpacked(data_scaled)
-        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-        data_lp = pack_uint4(data_lp)
-        return out_scales.to(torch.float8_e4m3fn), data_lp
+    data_lp = f32_to_f4_unpacked(data_scaled)
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales, data_lp
```
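
The behavioral core of this change is that block scales are now rounded to fp8 with a plain cast rather than through the removed `_Float8Round` STE function, since the QAT path no longer needs a differentiable version of this code. A standalone sketch of that rounding step follows; the constants are stand-ins derived from `torch.finfo`, not necessarily the module's own `E4M3_EPS` / `F8E4M3_MAX` definitions.

```python
import torch

# Stand-in constants (assumed): the module defines its own E4M3_EPS / F8E4M3_MAX.
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny    # smallest normal e4m3 value


def round_block_scale_to_e4m3(block_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Clamp to the e4m3 range and cast to float8_e4m3fn; return the fp8 scale
    (stored alongside the packed data) and its float32 copy (used to scale the data)."""
    block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
        torch.float8_e4m3fn
    )
    return block_scale_fp8, block_scale_fp8.to(torch.float32)


# e.g. block_scale = per-block amax / F4_E2M1_MAX (6.0 for fp4 e2m1)
scales_fp8, scales_fp32 = round_block_scale_to_e4m3(torch.rand(8, 4) * 100)
```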

torchao/prototype/qat/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3,10 +3,10 @@
 
 from .nvfp4 import (
     NVFP4FakeQuantizeConfig,
-    NVFP4FakeQuantizer,
+    NVFP4FakeQuantizedLinear,
 )
 
 __all__ = [
     "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
+    "NVFP4FakeQuantizedLinear",
 ]
```

torchao/prototype/qat/nvfp4.py

Lines changed: 148 additions & 34 deletions
```diff
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase
 
 
 @dataclass
@@ -23,47 +22,162 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """
 
     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
 
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
+class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    Autograd function for NVFP4 quantization + addmm in low precision during forward,
+    and fake quantization in high precision during backward.
     """
 
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
+
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
+    """
+    Linear module for fake quantized NVFP4 weights and/or activations.
+
+    The forward pass follows quantization and addmm numerics in `NVFP4Tensor`
+    in lower precision exactly, while the backward pass uses dequantize
+    (fake quantized) values in high precision.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        base_config = NVFP4InferenceConfig()
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        # Model contains `NVFP4FakeQuantizedLinear` now
+
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+        # Model contains `nn.Linear` with `NVFP4Tensor` weights now
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
 
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
```
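
For reference, a minimal sketch (not from the PR) of using the new `from_linear` helper directly on a single `nn.Linear`; in practice the `quantize_` + `QATConfig` flow performs this swap automatically. It assumes a CUDA device that supports the NVFP4 kernels (the unit tests require sm89+) and an `in_features` divisible by 16.

```python
# Illustrative only: swapping one nn.Linear by hand via `from_linear`.
import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

linear = torch.nn.Linear(128, 64, bias=True, device="cuda", dtype=torch.bfloat16)
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
y = fq_linear(x)            # forward: NVFP4 quantization + addmm, matching PTQ numerics
y.float().sum().backward()  # backward: high-precision grads from dequantized values
```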

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, NVFP4InferenceConfig):
         if NVFP4MMConfig.DYNAMIC:
             act_config = NVFP4FakeQuantizeConfig(
-                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+                use_swizzled_scales=False,
+                use_triton_kernel=False,
             )
         else:
             act_config = None
         weight_config = NVFP4FakeQuantizeConfig(
-            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+            use_swizzled_scales=True,
+            use_triton_kernel=base_config.use_triton_kernel,
         )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
```
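
Written out by hand, the configs this branch now derives from an `NVFP4InferenceConfig` look roughly like the following (a sketch assuming dynamic per-tensor activation scaling is enabled and the triton kernel is off); note the asymmetry: weights get swizzled scales while activations keep row-major scales.

```python
# Illustrative only: the inferred activation/weight fake-quantize configs spelled out.
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

act_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=True,
    use_swizzled_scales=False,  # activations keep row-major scales
    use_triton_kernel=False,
)
weight_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=True,
    use_swizzled_scales=True,   # weights use the swizzled (blocked) scale layout
    use_triton_kernel=False,    # follows base_config.use_triton_kernel in the real code
)
```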

torchao/quantization/qat/fake_quantizer.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -60,20 +60,12 @@ def __repr__(self) -> str:
 
     @staticmethod
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
-        # TODO: rewrite using registration API so we don't need to import here
-        from torchao.prototype.qat import (
-            NVFP4FakeQuantizeConfig,
-            NVFP4FakeQuantizer,
-        )
-
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
         elif isinstance(config, Int4WeightFakeQuantizeConfig):
             return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
-        elif isinstance(config, NVFP4FakeQuantizeConfig):
-            return NVFP4FakeQuantizer(config)
         else:
             raise ValueError(f"Unknown config type: {config}")
 
```