
Commit 3ed0110

Improve QAT nvfp4 numerics
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.

ghstack-source-id: a707a59
Pull Request resolved: #3050
1 parent e1d89e7 commit 3ed0110
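
As background for the SQNR numbers above, the prepare vs convert comparison follows the pattern in `_test_quantize_api_against_ptq` below: quantize a PTQ copy of the model, prepare a QAT copy, and compare outputs. The following is only a minimal sketch of that flow, not the test itself; the toy model, shapes, and the `compute_error` import path are assumptions, and running it requires a GPU that torchao's NVFP4 path supports.

```python
import copy
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization.utils import compute_error  # SQNR helper used by the tests

# Toy model standing in for the test suite's `M`; in_features must be a multiple of 16.
m = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).cuda().bfloat16()
baseline = copy.deepcopy(m)
x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(baseline, base_config)                      # PTQ reference
quantize_(m, QATConfig(base_config, step="prepare"))  # QAT fake quantization
prepare_sqnr = compute_error(m(x), baseline(x))       # now inf: prepare matches PTQ exactly

quantize_(m, QATConfig(base_config, step="convert"))  # swap in real NVFP4 weights
convert_sqnr = compute_error(m(x), baseline(x))       # also inf
```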

7 files changed (+173, -87 lines)


test/quantization/test_qat.py

Lines changed: 8 additions & 3 deletions
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
 
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @parametrize("use_per_tensor_scale", [True, False])
     def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreaterEqual(sqnr, float("inf"))
 
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
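
For reference, `test_qat_nvfp4` above drives QAT through explicit fake-quantize configs rather than a base PTQ config. Roughly, the prepare call looks like the sketch below; `model` is a placeholder and the `step="prepare"` argument is an assumption about the unchanged lines this hunk cuts off. Under the refactor, preparing presumably swaps matching `nn.Linear` modules for `NVFP4FakeQuantizedLinear` (that wiring lives outside the hunks shown here).

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    step="prepare",
)
quantize_(model, qat_config)  # `model` is any module tree containing nn.Linear layers
```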

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 11 additions & 36 deletions
@@ -768,37 +768,13 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
-    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
-
-
-class _Float8Round(torch.autograd.Function):
-    """
-    Cast a tensor to float8 and back to float32 with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return x.to(torch.float8_e4m3fn).to(torch.float32)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
-def _nvfp4_quantize(
-    data_hp: torch.Tensor,
-    block_size: int = 16,
-    per_tensor_scale: Optional[torch.Tensor] = None,
-    skip_dtype_cast_and_packing: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -810,8 +786,10 @@ def _nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
-        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -823,8 +801,8 @@ def _nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        )
-        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -833,11 +811,8 @@ def _nvfp4_quantize(
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
-    else:
-        data_lp = f32_to_f4_unpacked(data_scaled)
-        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-        data_lp = pack_uint4(data_lp)
-        return out_scales.to(torch.float8_e4m3fn), data_lp
+    data_lp = f32_to_f4_unpacked(data_scaled)
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales, data_lp
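
For intuition on why `_Float8Round` can be dropped here: in the forward direction, clamping and casting through `float8_e4m3fn` produces the same values the autograd helper did, and the straight-through backward is no longer needed on this PTQ path because QAT gradients now flow through `_NVFP4FakeQuantizedForward` in `torchao/prototype/qat/nvfp4.py`. A standalone sketch of the scale rounding, using `torch.finfo` bounds that are assumed to match the module's `E4M3_EPS`/`F8E4M3_MAX` constants:

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)
block_amax = torch.tensor([3.7, 120.0, 0.0009])  # per-block absolute maxima
block_scale = block_amax / 6.0                   # 6.0 is the E2M1 (FP4) max value

# Same forward math as the new inline code: clamp, round through e4m3, dequantize.
block_scale_fp8 = torch.clamp(block_scale, min=finfo.tiny, max=finfo.max).to(
    torch.float8_e4m3fn
)
block_scale_fp32 = block_scale_fp8.to(torch.float32)
print(block_scale_fp32)  # e4m3-representable scales; data blocks are divided by these
```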

torchao/prototype/qat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -3,10 +3,10 @@
 
 from .nvfp4 import (
     NVFP4FakeQuantizeConfig,
-    NVFP4FakeQuantizer,
+    NVFP4FakeQuantizedLinear,
 )
 
 __all__ = [
     "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
+    "NVFP4FakeQuantizedLinear",
 ]

torchao/prototype/qat/nvfp4.py

Lines changed: 130 additions & 34 deletions
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase
 
 
 @dataclass
@@ -23,47 +22,144 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """
 
     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
+class _NVFP4FakeQuantizedForward(torch.autograd.Function):
+    """
+    TODO: write me
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
+
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
+
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
 
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    TODO: write me
     """
 
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
             x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4FakeQuantizedForward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
 
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
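
A hedged usage sketch of the new linear class, assuming a CUDA device and torchao build that support the NVFP4 matmul dispatch; `in_features` must be a multiple of the 16-element block size, and the shapes here are arbitrary:

```python
import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

linear = torch.nn.Linear(64, 32, bias=False, device="cuda", dtype=torch.bfloat16)
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
)

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = fq_linear(x)    # same quantize + mm path as NVFP4 PTQ
out.sum().backward()  # gradients flow through the dequantized NVFP4 tensors
```

Folding the mm into the fake-quantized module, rather than dequantizing and running a high-precision matmul, is what closes the prepare vs convert gap reported in the summary.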

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 6 additions & 2 deletions
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, NVFP4InferenceConfig):
         if NVFP4MMConfig.DYNAMIC:
             act_config = NVFP4FakeQuantizeConfig(
-                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+                use_swizzled_scales=False,
+                use_triton_kernel=False,
             )
         else:
             act_config = None
         weight_config = NVFP4FakeQuantizeConfig(
-            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+            use_swizzled_scales=True,
+            use_triton_kernel=base_config.use_triton_kernel,
         )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"

torchao/quantization/qat/fake_quantizer.py

Lines changed: 0 additions & 8 deletions
@@ -60,20 +60,12 @@ def __repr__(self) -> str:
 
     @staticmethod
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
-        # TODO: rewrite using registration API so we don't need to import here
-        from torchao.prototype.qat import (
-            NVFP4FakeQuantizeConfig,
-            NVFP4FakeQuantizer,
-        )
-
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
         elif isinstance(config, Int4WeightFakeQuantizeConfig):
             return Int4WeightFakeQuantizer(config)
        elif isinstance(config, Float8FakeQuantizeConfig):
            return Float8FakeQuantizer(config)
-        elif isinstance(config, NVFP4FakeQuantizeConfig):
-            return NVFP4FakeQuantizer(config)
        else:
            raise ValueError(f"Unknown config type: {config}")
 
