Commit a024e29

Update on "Add NVFP4 QAT"
**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtype casting, swizzling, and packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

Initial benchmarks from fine-tuning Qwen3-1.7B on alpaca for 3 epochs:

```
# Without QAT
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.8322|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7804|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |21.8611|±  |   N/A|

# With QAT
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.8271|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7741|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |21.4467|±  |   N/A|
```

[ghstack-poisoned]
2 parents 80cc501 + bc2a059 commit a024e29
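Since this update moves the NVFP4 QAT classes into `torchao.prototype.qat` (see the file diffs below), a minimal end-to-end sketch of the prepare step using the new import path may help; the toy model and the training-loop comment are illustrative assumptions, not part of this diff:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig  # new location in this commit

# Toy model for illustration; in-features divisible by the NVFP4
# block size of 16 so the fake quantizer can reshape cleanly
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()

# Prepare: swap in fake-quantized linears that apply NVFP4
# quantize->dequantize numerics to weights and activations
qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(),
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)

# ... fine-tune as usual, then follow the standard QAT convert step ...
```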

File tree

7 files changed, +93 −71 lines changed


docs/source/api_ref_qat.rst

Lines changed: 0 additions & 2 deletions

```diff
@@ -62,5 +62,3 @@ Prototype
     :nosignatures:
 
     initialize_fake_quantizers
-    NVFP4FakeQuantizeConfig
-    NVFP4FakeQuantizer
```

test/quantization/test_qat.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -50,7 +50,6 @@
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
-    NVFP4FakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
     Float8FakeQuantizer,
@@ -1974,6 +1973,8 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.qat import NVFP4FakeQuantizeConfig
+
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
```

torchao/prototype/qat/__init__.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -0,0 +1,12 @@
+# Temporary location for prototype QAT features that will
+# eventually live in torchao/quantization/qat
+
+from .nvfp4 import (
+    NVFP4FakeQuantizeConfig,
+    NVFP4FakeQuantizer,
+)
+
+__all__ = [
+    "NVFP4FakeQuantizeConfig",
+    "NVFP4FakeQuantizer",
+]
```

torchao/prototype/qat/nvfp4.py

Lines changed: 69 additions & 0 deletions

```diff
@@ -0,0 +1,69 @@
+from dataclasses import dataclass
+
+import torch
+
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    _nvfp4_quantize,
+    per_tensor_amax_to_scale,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfigBase,
+    FakeQuantizerBase,
+)
+
+
+@dataclass
+class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
+
+    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
+
+    Args:
+        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
+            after the initial fp8 (e4m3) block-wise scaling (default True)
+    """
+
+    use_per_tensor_scale: bool = True
+
+
+class NVFP4FakeQuantizer(FakeQuantizerBase):
+    """
+    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: NVFP4FakeQuantizeConfig):
+        super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        block_size = 16
+        original_shape = x.shape
+        if x.dim() == 3:
+            x = x.view(-1, x.shape[-1])
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+
+        # quantize
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        if self.config.use_per_tensor_scale:
+            scale = scale * per_tensor_scale
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+
+        # dequantize
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(original_shape).to(x.dtype)
```
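For reference, a short hedged sketch of exercising the new module directly; the shape and dtype are arbitrary example values, with the last dimension kept divisible by the hard-coded block size of 16:

```python
import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
fake_quantizer = NVFP4FakeQuantizer(config)

# Last dim must be divisible by block_size=16 for the views in forward()
x = torch.randn(32, 64, dtype=torch.bfloat16)
dq = fake_quantizer(x)

# Output round-trips shape and dtype: values pass through NVFP4
# quantize->dequantize numerics but stay in the original dtype
assert dq.shape == x.shape and dq.dtype == x.dtype
```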

torchao/quantization/qat/__init__.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -17,14 +17,12 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
-    NVFP4FakeQuantizeConfig,
 )
 from .fake_quantizer import (
     FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
-    NVFP4FakeQuantizer,
 )
 from .linear import (
     FakeQuantizedLinear,
@@ -42,8 +40,6 @@
     "Float8FakeQuantizer",
     "IntxFakeQuantizeConfig",
     "IntxFakeQuantizer",
-    "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
```
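One consequence of these removals, sketched briefly: the NVFP4 names can no longer be imported from `torchao.quantization.qat` and must come from the prototype namespace instead:

```python
# New import location added in this commit
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

# The old imports were removed from torchao/quantization/qat/__init__.py
# and would now raise ImportError:
#   from torchao.quantization.qat import NVFP4FakeQuantizeConfig
#   from torchao.quantization.qat import NVFP4FakeQuantizer
```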

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 4 additions & 17 deletions

```diff
@@ -77,22 +77,6 @@ def __post_init__(self):
         )
 
 
-@dataclass
-class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
-    """
-    (Prototype) Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
-    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
-
-    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
-
-    Args:
-        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
-            after the initial fp8 (e4m3) block-wise scaling (default True)
-    """
-
-    use_per_tensor_scale: bool = True
-
-
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -336,7 +320,6 @@ def __post_init__(self):
     _log_deprecation_warning(self)
 
 
-# TODO: rewrite using registration API?
 def _infer_fake_quantize_configs(
     base_config: AOBaseConfig,
 ) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
@@ -347,11 +330,15 @@ def _infer_fake_quantize_configs(
 
     Return a 2-tuple of (activation_config, weight_config) for fake quantization.
     """
+    # TODO: rewrite using registration API so we don't need to import here
     # avoid circular imports
     from torchao.prototype.mx_formats import (
         NVFP4InferenceConfig,
         NVFP4MMConfig,
     )
+    from torchao.prototype.qat import (
+        NVFP4FakeQuantizeConfig,
+    )
     from torchao.quantization import (
         Float8DynamicActivationFloat8WeightConfig,
         Float8DynamicActivationInt4WeightConfig,
```
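The NVFP4 branch of `_infer_fake_quantize_configs` falls outside this hunk; presumably it now builds the prototype config from `NVFP4InferenceConfig`, roughly along these lines (a hypothetical sketch: the `mm_config` check is inferred from the `NVFP4MMConfig` import above, not from anything shown in this diff):

```python
# Hypothetical sketch of the NVFP4 branch, not shown in this hunk
if isinstance(base_config, NVFP4InferenceConfig):
    # Assumption: only fake-quantize activations when the inference
    # config quantizes activations dynamically
    if base_config.mm_config == NVFP4MMConfig.DYNAMIC:
        act_config = NVFP4FakeQuantizeConfig()
    else:
        act_config = None
    weight_config = NVFP4FakeQuantizeConfig()
    return (act_config, weight_config)
```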

torchao/quantization/qat/fake_quantizer.py

Lines changed: 6 additions & 47 deletions

```diff
@@ -34,7 +34,6 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
-    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -58,6 +57,12 @@ def __repr__(self) -> str:
 
     @staticmethod
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
+        # TODO: rewrite using registration API so we don't need to import here
+        from torchao.prototype.qat import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizer,
+        )
+
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
@@ -95,52 +100,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
-    """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
-    """
-
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        from torchao.prototype.mx_formats.nvfp4_tensor import (
-            _nvfp4_quantize,
-            per_tensor_amax_to_scale,
-        )
-
-        block_size = 16
-        original_shape = x.shape
-        if x.dim() == 3:
-            x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
-        else:
-            per_tensor_scale = None
-
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
-        )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
-
-
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
```
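From the caller's side, these local imports presumably let `from_config` dispatch the prototype config to the prototype quantizer; a hedged sketch of the expected behavior:

```python
from torchao.quantization.qat import FakeQuantizerBase
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

# from_config resolves the config type to its matching quantizer module
fake_quantizer = FakeQuantizerBase.from_config(NVFP4FakeQuantizeConfig())
assert isinstance(fake_quantizer, NVFP4FakeQuantizer)
```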
