Commit a369ffa

Support NVFP4 dynamic per tensor scale
**Summary:** This commit adds an option to the existing `NVFP4InferenceConfig` to dynamically compute an appropriate fp32 per-tensor scale, supporting the two-level scaling described in the NVFP4 specification: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. While two-level scaling is supported in `NVFP4Tensor`, today there is no config API for users to enable it. The existing `NVFP4InferenceConfig` only supports single-level scaling because including an explicit `per_tensor_scale` field would make serialization tricky. In the future, we should add an end-to-end calibration flow so users can first compute an appropriate per-tensor scale for the activations and then pass it to `NVFP4Tensor` as a static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```python
import copy
import time

import torch

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

ghstack-source-id: e6c06b6
Pull Request resolved: #3049
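As a quick illustration of the two-level scheme, here is a minimal sketch of how a dynamic fp32 per-tensor scale can be derived from a tensor's absolute maximum. The constant values and the helper name `per_tensor_amax_to_scale_sketch` are illustrative assumptions, not the library's API; the helper actually used in this commit is `per_tensor_amax_to_scale`, imported in `inference_workflow.py` below.

```python
import torch

# Assumed constants for the sketch: largest representable magnitudes of the two
# formats involved in NVFP4's two-level scaling (per-block float8_e4m3fn scales
# and float4_e2m1fn values).
F8E4M3_MAX = 448.0
F4_E2M1_MAX = 6.0

def per_tensor_amax_to_scale_sketch(amax: torch.Tensor) -> torch.Tensor:
    """Map a tensor's absolute maximum onto an fp32 second-level scale."""
    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)

# Hypothetical usage: derive a weight's per-tensor scale dynamically, mirroring
# what `use_dynamic_per_tensor_scale=True` enables in this commit.
weight = torch.randn(256, 64, dtype=torch.bfloat16)
per_tensor_scale = per_tensor_amax_to_scale_sketch(torch.abs(weight).max())
```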

5 files changed: +49 -17 lines changed

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -105,6 +105,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
 )
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
 @pytest.mark.parametrize(
     "shapes",
     [
@@ -126,6 +127,7 @@ def test_inference_workflow_nvfp4(
     mm_config: NVFP4MMConfig,
     inpt_dtype: torch.dtype,
     use_triton_kernel: bool,
+    use_dynamic_per_tensor_scale: bool,
     shapes: tuple,
 ):
     """
@@ -147,7 +149,9 @@ def test_inference_workflow_nvfp4(
     m_mx = copy.deepcopy(m)
 
     config = NVFP4InferenceConfig(
-        mm_config=mm_config, use_triton_kernel=use_triton_kernel
+        mm_config=mm_config,
+        use_triton_kernel=use_triton_kernel,
+        use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
     )
     quantize_(m_mx, config=config)
```

test/quantization/test_qat.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -2077,7 +2077,8 @@ def test_infer_int4_weight_only_config(self):
         self.assertEqual(weight_config.activation_dtype, torch.bfloat16)
 
     @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-    def test_quantize_api_nvfp4(self):
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test the following:
             quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
@@ -2086,8 +2087,8 @@ def test_quantize_api_nvfp4(self):
         from torchao.prototype.mx_formats import NVFP4InferenceConfig
 
         self._test_quantize_api_against_ptq(
-            NVFP4InferenceConfig(),
-            target_prepare_sqnr=8,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+            target_prepare_sqnr=12,
             target_convert_sqnr=float("inf"),
         )
```

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -22,6 +22,7 @@
     NVFP4MMConfig,
     NVFP4Tensor,
     QuantizeTensorToNVFP4Kwargs,
+    per_tensor_amax_to_scale,
 )
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -134,7 +135,8 @@ class NVFP4InferenceConfig(AOBaseConfig):
     This is a specialized configuration for NVIDIA's FP4 format.
     Configuration parameters:
     - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision)
-    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: False)
+    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True)
+    - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
     - Data: float4_e2m1fn_x2
     - Scales: float8_e4m3fn
     - Block size: 16 along the reduction dim
@@ -145,6 +147,7 @@ class NVFP4InferenceConfig(AOBaseConfig):
 
     mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
     use_triton_kernel: bool = True
+    use_dynamic_per_tensor_scale: bool = True
 
     def __post_init__(self):
         # Validate PyTorch version
@@ -175,12 +178,20 @@ def _nvfp4_inference_linear_transform(
             "Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
         )
 
+    per_tensor_scale = None
+    if config.use_dynamic_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(weight))
+        per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+
     act_quant_kwargs = None
     if config.mm_config == NVFP4MMConfig.DYNAMIC:
-        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
+        act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
+            use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
+        )
 
     quantized_weight = NVFP4Tensor.to_nvfp4(
         weight,
+        per_tensor_scale=per_tensor_scale,
         is_swizzled_scales=True,
         use_triton_kernel=False,  # Always use traditional construction for weights
         act_quant_kwargs=act_quant_kwargs,
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -47,6 +47,7 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
     block_size: int = 16
     is_swizzled_scales: bool = False
     use_triton_kernel: bool = False
+    use_dynamic_per_tensor_scale: bool = False
 
 
 # TODO(future PR): move over to TorchAOBaseTensor's dispatch
@@ -245,7 +246,7 @@ def get_hp_scales(self) -> torch.Tensor:
 
         return (
             scale_e4m3.to(self._orig_dtype)
-            if not self._per_tensor_scale
+            if self._per_tensor_scale is None
             else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
         )
 
@@ -645,10 +646,15 @@ def nvfp4_linear(func, types, args, kwargs):
     else:
         # dynamic quant
        k = weight_tensor.act_quant_kwargs
+        if k.use_dynamic_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(input_tensor))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = weight_tensor._act_per_tensor_scale
         input_tensor = NVFP4Tensor.to_nvfp4(
             input_tensor,
             block_size=k.block_size,
-            per_tensor_scale=weight_tensor._act_per_tensor_scale,
+            per_tensor_scale=per_tensor_scale,
             is_swizzled_scales=k.is_swizzled_scales,
             use_triton_kernel=k.use_triton_kernel,
         )
@@ -672,10 +678,15 @@ def nvfp4_mm(func, types, args, kwargs):
     else:
         if not isinstance(input_tensor, NVFP4Tensor):
             k = weight_tensor.act_quant_kwargs
+            if k.use_dynamic_per_tensor_scale:
+                tensor_amax = torch.max(torch.abs(input_tensor))
+                per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+            else:
+                per_tensor_scale = weight_tensor._act_per_tensor_scale
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
                 block_size=k.block_size,
-                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                per_tensor_scale=per_tensor_scale,
                 is_swizzled_scales=k.is_swizzled_scales,
                 use_triton_kernel=k.use_triton_kernel,
             )
@@ -697,12 +708,18 @@ def nvfp4_addmm(func, types, args, kwargs):
         else:
             return torch.addmm(bias, input_tensor, weight_dequant)
     else:
+        # TODO: refactor duplicate code
         if not isinstance(input_tensor, NVFP4Tensor):
             k = weight_tensor.act_quant_kwargs
+            if k.use_dynamic_per_tensor_scale:
+                tensor_amax = torch.max(torch.abs(input_tensor))
+                per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+            else:
+                per_tensor_scale = weight_tensor._act_per_tensor_scale
             input_tensor = NVFP4Tensor.to_nvfp4(
                 input_tensor,
                 block_size=k.block_size,
-                per_tensor_scale=weight_tensor._act_per_tensor_scale,
+                per_tensor_scale=per_tensor_scale,
                 is_swizzled_scales=k.is_swizzled_scales,
                 use_triton_kernel=k.use_triton_kernel,
             )
```
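The `get_hp_scales` change above switches from a truthiness check to an explicit `is None` test so that a provided per-tensor scale is always applied. A minimal sketch of how the two scale levels combine at dequantization, using hypothetical standalone names (`scale_e4m3`, `per_tensor_scale`, `orig_dtype`) in place of the tensor subclass attributes:

```python
from typing import Optional

import torch

def combined_hp_scales_sketch(
    scale_e4m3: torch.Tensor,                  # per-block scales stored as float8_e4m3fn
    per_tensor_scale: Optional[torch.Tensor],  # optional fp32 second-level scale
    orig_dtype: torch.dtype,
) -> torch.Tensor:
    # First level: per-block e4m3 scales, cast back to the original dtype.
    block_scales = scale_e4m3.to(orig_dtype)
    # Second level: multiply in the fp32 per-tensor scale only when one is set.
    # Checking `is None` (rather than truthiness) avoids evaluating the
    # tensor's value and keeps a valid scale from being silently skipped.
    if per_tensor_scale is None:
        return block_scales
    return per_tensor_scale * block_scales
```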

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -442,16 +442,15 @@ def _infer_fake_quantize_configs(
             activation_dtype=e4m3_dtype,
         )
     elif isinstance(base_config, NVFP4InferenceConfig):
-        # Note: today the PTQ config does not allow the user to specify
-        # `per_tensor_scales` due to serialization concerns. In the future
-        # we may add a way to compute these dynamically (for activations),
-        # but for now QAT will mimic the existing behavior of not having
-        # `per_tensor_scales` (subject to change)
         if NVFP4MMConfig.DYNAMIC:
-            act_config = NVFP4FakeQuantizeConfig(False)
+            act_config = NVFP4FakeQuantizeConfig(
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            )
         else:
             act_config = None
-        weight_config = NVFP4FakeQuantizeConfig(False)
+        weight_config = NVFP4FakeQuantizeConfig(
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+        )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
         assert base_config.intx_packing_format == "unpacked_to_int8", (
```