Commit c7f22a1

Improve QAT nvfp4 numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.

ghstack-source-id: ecbff90
Pull Request resolved: #3050
1 parent e1d89e7 commit c7f22a1
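
For orientation, here is a minimal sketch of the prepare → convert flow whose numerics this commit aligns, mirroring the `quantize_`/`QATConfig` usage in the tests below. The toy model, shapes, and the assumption of a CUDA GPU with bf16 weights are illustrative, not taken from this commit:

```python
import copy

import torch

from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

# toy model; NVFP4 needs the inner dim divisible by the block size (16)
m = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().bfloat16()
x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)

# PTQ baseline: quantize a copy of the model directly
baseline = copy.deepcopy(m)
quantize_(baseline, base_config)

# QAT prepare: swap in fake-quantized linears that now mimic the PTQ numerics
quantize_(m, QATConfig(base_config, step="prepare"))
out_prepared = m(x)

# QAT convert: lower the trained model to the real NVFP4 representation
quantize_(m, QATConfig(base_config, step="convert"))
out_converted = m(x)  # should now match out_prepared (SQNR = inf)
```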

6 files changed (+169, -79 lines)


test/quantization/test_qat.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)

         # compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):

         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )

+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @parametrize("use_per_tensor_scale", [True, False])
     def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig

         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreaterEqual(sqnr, float("inf"))

     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
```
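
The prepare/convert targets above are SQNR values in dB; a self-contained sketch of what an infinite target means, using a local `sqnr` helper as a stand-in for torchao's `compute_error`:

```python
import torch

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; inf when the tensors match exactly
    return 20 * torch.log10(ref.norm() / (ref - test).norm())

a = torch.randn(8, 8)
print(sqnr(a, a.clone()))                        # tensor(inf): bit-exact match
print(sqnr(a, a + 0.01 * torch.randn_like(a)))   # finite, roughly 40 dB
```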

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 11 additions & 36 deletions
```diff
@@ -768,37 +768,13 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
-    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
-
-
-class _Float8Round(torch.autograd.Function):
-    """
-    Cast a tensor to float8 and back to float32 with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return x.to(torch.float8_e4m3fn).to(torch.float32)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
-def _nvfp4_quantize(
-    data_hp: torch.Tensor,
-    block_size: int = 16,
-    per_tensor_scale: Optional[torch.Tensor] = None,
-    skip_dtype_cast_and_packing: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"

-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -810,8 +786,10 @@ def _nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
-        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -823,8 +801,8 @@
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        )
-        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -833,11 +811,8 @@

     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
-    else:
-        data_lp = f32_to_f4_unpacked(data_scaled)
-        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-        data_lp = pack_uint4(data_lp)
-        return out_scales.to(torch.float8_e4m3fn), data_lp
+    data_lp = f32_to_f4_unpacked(data_scaled)
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales.to(torch.float8_e4m3fn), data_lp
```
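
The change above drops the `_Float8Round` autograd wrapper: block scales are now clamped and round-tripped through `float8_e4m3fn` directly. A standalone sketch of that round-trip, with `torch.finfo` values standing in for the module's `E4M3_EPS`/`F8E4M3_MAX` constants (an assumption; the real constants are defined elsewhere in torchao):

```python
import torch

# stand-ins for E4M3_EPS / F8E4M3_MAX, derived from torch.finfo rather than imported
finfo = torch.finfo(torch.float8_e4m3fn)
eps, fp8_max = finfo.tiny, finfo.max

# some per-block fp32 scales, shape [num_rows, num_blocks]
block_scale = torch.rand(4, 8) * 500.0

# old path: clamp, then _Float8Round.apply(...) (cast to fp8 and back with an STE)
# new path: clamp, cast to fp8 e4m3, and cast straight back to fp32
block_scale_fp8 = torch.clamp(block_scale, min=eps, max=fp8_max).to(torch.float8_e4m3fn)
block_scale_fp32 = block_scale_fp8.to(torch.float32)

# the data in each 16-element block is then divided by its rounded scale
data = torch.randn(4, 8, 16)
data_scaled = data / block_scale_fp32.unsqueeze(-1)
```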

torchao/prototype/qat/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3,10 +3,10 @@

 from .nvfp4 import (
     NVFP4FakeQuantizeConfig,
-    NVFP4FakeQuantizer,
+    NVFP4FakeQuantizedLinear,
 )

 __all__ = [
     "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
+    "NVFP4FakeQuantizedLinear",
 ]
```

torchao/prototype/qat/nvfp4.py

Lines changed: 126 additions & 34 deletions
```diff
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional

 import torch

 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase


 @dataclass
@@ -23,47 +22,140 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """

     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False


-class NVFP4FakeQuantizer(FakeQuantizerBase):
+class _NVFP4FakeQuantizedForward(torch.autograd.Function):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    TODO: write me
     """

-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        ctx.save_for_backward(_input, weight)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
-        if x.dim() == 3:
-            x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
             per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
             per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )

-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
             per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
+    """
+    TODO: write me
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dim() == 3:
+            batch_size = x.shape[0]
+            x = x.view(-1, x.shape[-1])
+        else:
+            batch_size = None
+        fq = _NVFP4FakeQuantizedForward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
+
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
```
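
A short usage sketch for the new module above, swapping a plain linear for an `NVFP4FakeQuantizedLinear` via `from_linear`. The shapes, bf16 dtype, and the assumption of a CUDA GPU that supports the NVFP4 mm path (the tests gate on sm89+) are illustrative:

```python
import torch

from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

linear = torch.nn.Linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16)
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
)

# forward runs the NVFP4 quantize + mm path; backward is a plain mm
# against the saved high-precision input and weight
x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = fq_linear(x)
out.sum().backward()
```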

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, NVFP4InferenceConfig):
         if NVFP4MMConfig.DYNAMIC:
             act_config = NVFP4FakeQuantizeConfig(
-                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+                use_swizzled_scales=False,
+                use_triton_kernel=False,
             )
         else:
             act_config = None
         weight_config = NVFP4FakeQuantizeConfig(
-            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+            use_swizzled_scales=True,
+            use_triton_kernel=base_config.use_triton_kernel,
         )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
```
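
Per the mapping above, the inferred fake-quantize configs mirror the PTQ layout choices: swizzled scales and the optional triton kernel only on the weight side. Roughly, the resulting configs look like this (a sketch of the outcome, not a call into the private helper):

```python
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)

# activations: non-swizzled scales, no triton kernel
act_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
    use_swizzled_scales=False,
    use_triton_kernel=False,
)

# weights: swizzled scales, triton kernel if the inference config asks for it
weight_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
    use_swizzled_scales=True,
    use_triton_kernel=base_config.use_triton_kernel,
)
```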

torchao/quantization/qat/linear.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -130,13 +130,27 @@ def to_linear(self) -> torch.nn.Linear:
         new_linear.bias = self.bias
         return new_linear

-    @classmethod
+    @staticmethod
     def from_linear(
-        cls,
         mod: torch.nn.Linear,
         activation_config: Optional[FakeQuantizeConfigBase] = None,
         weight_config: Optional[FakeQuantizeConfigBase] = None,
     ):
+        # TODO: rewrite this using a registration API so
+        # specific quantization schemes do not leak here
+        from torchao.prototype.qat import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizedLinear,
+        )
+
+        if isinstance(weight_config, NVFP4FakeQuantizeConfig):
+            assert activation_config is None or isinstance(
+                activation_config, NVFP4FakeQuantizeConfig
+            )
+            return NVFP4FakeQuantizedLinear.from_linear(
+                mod, activation_config, weight_config
+            )
+
         new_linear = FakeQuantizedLinear(
             mod.in_features,
             mod.out_features,
```
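
With the dispatch above, the generic `FakeQuantizedLinear.from_linear` now hands NVFP4 configs off to the new class. A brief sketch (the `torchao.quantization.qat.linear` import path is assumed from the file being edited):

```python
import torch

from torchao.prototype.qat import NVFP4FakeQuantizeConfig
from torchao.quantization.qat.linear import FakeQuantizedLinear

lin = torch.nn.Linear(64, 32)
fq = FakeQuantizedLinear.from_linear(
    lin,
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
)
# an NVFP4 weight config routes to the new class instead of FakeQuantizedLinear
print(type(fq).__name__)  # NVFP4FakeQuantizedLinear
```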
