Commit d3cbc06

7/x mx cleanup: standardize nvfp4 config names
Summary:

Splits `NVFP4InferenceConfig` into:
1. `NVFP4DynamicActivationNVFP4WeightConfig` for dynamic quant
2. `NVFP4WeightOnlyConfig` for weight-only quant

Test Plan:

```
pytest test/prototype/mx_formats -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: f0894bc
ghstack-comment-id: 3583258138
Pull-Request: #3398
1 parent 4838397 commit d3cbc06
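
At a glance, the rename maps the old single-config API (one class plus an `mm_config` enum) onto two mode-specific config classes. A minimal before/after sketch, with class names and keyword arguments taken from the diffs below; the "before" half only applies to versions prior to this commit:

```python
# Before this commit: one config, mode selected via an enum.
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4InferenceConfig,
    NVFP4MMConfig,
)

old_config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.DYNAMIC,  # or NVFP4MMConfig.WEIGHT_ONLY
    use_dynamic_per_tensor_scale=True,
)

# After this commit: one config class per quantization mode.
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    NVFP4WeightOnlyConfig,
)

config_dynamic = NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=True,
    use_triton_kernel=True,
)
config_weight_only = NVFP4WeightOnlyConfig(
    use_dynamic_per_tensor_scale=True,
)
```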

File tree

11 files changed: +137 −101 lines

benchmarks/float8/float8_inference_roofline.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -40,8 +40,7 @@
 import torchao
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
 )
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization.quant_api import (
@@ -445,8 +444,7 @@ def run(
             kernel_preference=KernelPreference.AUTO,
         )
     elif recipe_name == "nvfp4":
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.DYNAMIC,
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
             use_dynamic_per_tensor_scale=False,
         )
     else:
```

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 22 additions & 20 deletions

```diff
@@ -14,8 +14,8 @@
 
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.quantize_.common import KernelPreference
@@ -138,9 +138,7 @@ def test_inference_workflow_mx(
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize(
-    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
-)
+@pytest.mark.parametrize("quant_type", ["dynamic", "weight_only"])
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("use_triton_kernel", [True, False])
 @pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
@@ -164,7 +162,7 @@ def test_inference_workflow_mx(
 def test_inference_workflow_nvfp4(
     bias: bool,
     compile: bool,
-    mm_config: NVFP4MMConfig,
+    quant_type: str,
     inpt_dtype: torch.dtype,
     use_triton_kernel: bool,
     use_dynamic_per_tensor_scale: bool,
@@ -177,14 +175,16 @@ def test_inference_workflow_nvfp4(
     Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
     """
     # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
-    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+    if quant_type == "dynamic" and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
 
     if bias and inpt_dtype == torch.float32:
         pytest.xfail("Bias is not supported when module weight is in fp32")
 
-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
-        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    if quant_type == "weight_only" and compile:
+        pytest.skip("TODO: weight_only quant currently errors w/ compile")
+    if quant_type == "weight_only" and use_triton_kernel:
+        pytest.skip("unsupported configuration")
 
     if use_inference_mode and (
         shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
@@ -200,11 +200,15 @@ def test_inference_workflow_nvfp4(
     m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
     m_mx = copy.deepcopy(m)
 
-    config = NVFP4InferenceConfig(
-        mm_config=mm_config,
-        use_triton_kernel=use_triton_kernel,
-        use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
-    )
+    if quant_type == "dynamic":
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_triton_kernel=use_triton_kernel,
+            use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
+        )
+    else:
+        config = NVFP4WeightOnlyConfig(
+            use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
+        )
     quantize_(m_mx, config=config)
 
     if compile:
@@ -216,7 +220,7 @@ def test_inference_workflow_nvfp4(
 
     y_ref = m(x)
 
-    if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
+    if use_triton_kernel and quant_type == "dynamic":
         with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
             y_mx = m_mx(x)
         assert result["found"], "Expected quantize_nvfp4 kernel to be found"
@@ -229,14 +233,14 @@ def test_inference_workflow_nvfp4(
 
     sqnr = compute_error(y_ref, y_mx)
 
-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+    if quant_type == "weight_only":
         SQNR_THRESHOLD = 18.0
     else:
         SQNR_THRESHOLD = 15.0
 
     assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
     assert sqnr >= SQNR_THRESHOLD, (
-        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, {quant_type=}"
     )
 
 
@@ -273,9 +277,7 @@ def test_narrow_similar_to_vllm(self):
         reason="torch.compile requires PyTorch 2.8+",
     )
     def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.WEIGHT_ONLY,
-            use_triton_kernel=False,
+        config = NVFP4WeightOnlyConfig(
             use_dynamic_per_tensor_scale=False,
         )
         self._test_quantize_3d_param_similar_to_vllm(config)
```

test/prototype/mx_formats/test_mx_serialization.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -14,8 +14,7 @@
 
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.quantize_.common import KernelPreference
@@ -48,8 +47,7 @@ def test_serialization(recipe_name):
         )
     else:
         assert recipe_name == "nvfp4", "unsupported"
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.DYNAMIC,
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
             use_triton_kernel=False,
             use_dynamic_per_tensor_scale=False,
         )
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -12,9 +12,6 @@
 from torchao.prototype.mx_formats.constants import (
     F4_E2M1_MAX,
 )
-from torchao.prototype.mx_formats.inference_workflow import (
-    NVFP4MMConfig,
-)
 from torchao.prototype.mx_formats.nvfp4_tensor import (
     NVFP4Tensor,
     QuantizeTensorToNVFP4Kwargs,
@@ -422,7 +419,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 )
 @pytest.mark.parametrize("use_gelu", [True, False])
 @pytest.mark.parametrize(
-    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+    "quant_type",
+    ["dynamic", "weight_only"],
 )
 @pytest.mark.parametrize("compile", [False])
 @pytest.mark.parametrize("bias", [True, False])
@@ -448,22 +446,22 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 )
 def test_nvfp4_matmul_with_amax(
     use_gelu: bool,
-    mm_config: NVFP4MMConfig,
+    quant_type: str,
     compile: bool,
     bias: bool,
     inpt_dtype: torch.dtype,
     use_triton_kernel: bool,
     shapes: tuple,
 ):
     # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
-    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+    if quant_type == "dynamic" and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
 
     if bias and inpt_dtype == torch.float32:
         pytest.xfail("Bias is not supported when module weight is in fp32")
 
-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
-        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    if quant_type == "weight_only" and compile:
+        pytest.skip("TODO: weight_only currently errors w/ compile")
 
     m, k, n = shapes
 
@@ -483,7 +481,7 @@ def test_nvfp4_matmul_with_amax(
     a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
     b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
     act_quant_kwargs = None
-    if mm_config == NVFP4MMConfig.DYNAMIC:
+    if quant_type == "dynamic":
         act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
     A_nvfp4 = NVFP4Tensor.to_nvfp4(
         A,
@@ -509,7 +507,7 @@ def test_nvfp4_matmul_with_amax(
     sqnr = compute_error(C_ref, C_nvfp4)
     SQNR_THRESHOLD = 16.0
     assert sqnr >= SQNR_THRESHOLD, (
-        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, {quant_type=}, compile={compile}, bias={bias}"
     )
 
 
```

test/quantization/test_qat.py

Lines changed: 10 additions & 6 deletions

```diff
@@ -2085,13 +2085,15 @@ def test_infer_int4_weight_only_config(self):
     def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test the following:
-            quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
-            quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
+            quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="convert"))
         """
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
 
         self._test_quantize_api_against_ptq(
-            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+            NVFP4DynamicActivationNVFP4WeightConfig(
+                use_dynamic_per_tensor_scale=use_per_tensor_scale
+            ),
             target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )
@@ -2103,15 +2105,17 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
        Test QAT with `NVFP4FakeQuantizeConfig`.
         """
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
         quantize_(
             baseline_model,
-            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+            NVFP4DynamicActivationNVFP4WeightConfig(
+                use_dynamic_per_tensor_scale=use_per_tensor_scale
+            ),
         )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
```
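
The docstring in the first hunk spells out the QAT flow with the renamed config. A minimal sketch of that prepare/convert sequence, assuming `QATConfig` is importable from `torchao.quantization.qat` (the import path is an assumption; the test references `QATConfig` directly):

```python
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig  # assumed import path

model = nn.Sequential(nn.Linear(256, 256)).cuda().bfloat16()
base_config = NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=True,
)

# Step 1: insert NVFP4 fake-quantize logic for training.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... run QAT fine-tuning here ...

# Step 2: swap fake-quantized modules for real NVFP4 quantized weights.
quantize_(model, QATConfig(base_config, step="convert"))
```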

torchao/prototype/mx_formats/README.md

Lines changed: 23 additions & 13 deletions

````diff
@@ -109,8 +109,8 @@ from torchao.quantization import quantize_
 import torchao.prototype.mx_formats
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
 )
 from torchao.quantization.quantize_.common import KernelPreference
 
@@ -129,6 +129,27 @@ quantize_(m_mxfp8, config=config)
 m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
 y_mxfp8 = m_mxfp8(x)
 
+# nvfp4 dynamic quant
+
+m_nvfp4 = copy.deepcopy(m)
+config = NVFP4DynamicActivationNVFP4WeightConfig(
+    use_dynamic_per_tensor_scale=True,
+    use_triton_kernel=True,
+)
+quantize_(m_nvfp4, config=config)
+m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
+y_nvfp4 = m_nvfp4(x)
+
+# nvfp4 weight-only quant
+
+m_nvfp4_wo = copy.deepcopy(m)
+config = NVFP4WeightOnlyConfig(
+    use_dynamic_per_tensor_scale=True,
+)
+quantize_(m_nvfp4_wo, config=config)
+m_nvfp4_wo = torch.compile(m_nvfp4_wo, fullgraph=True)
+y_nvfp4 = m_nvfp4_wo(x)
+
 # mxfp4
 
 m_mxfp4 = copy.deepcopy(m)
@@ -140,17 +161,6 @@ config = MXDynamicActivationMXWeightConfig(
 quantize_(m_mxfp4, config=config)
 m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
 y_mxfp4 = m_mxfp4(x)
-
-# nvfp4
-
-m_nvfp4 = copy.deepcopy(m)
-config = NVFP4InferenceConfig(
-    mm_config=NVFP4MMConfig.DYNAMIC,
-    use_dynamic_per_tensor_scale=True,
-)
-quantize_(m_nvfp4, config=config)
-m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
-y_nvfp4 = m_nvfp4(x)
 ```
 
 ## MXTensor
````

torchao/prototype/mx_formats/__init__.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -6,8 +6,8 @@
 # Note: Prototype and subject to change
 from torchao.prototype.mx_formats.inference_workflow import (
     MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
 )
 
 # import mx_linear here to register the quantize_ transform logic
@@ -18,6 +18,6 @@
     "MXLinearConfig",
     "MXLinearRecipeName",
     "MXDynamicActivationMXWeightConfig",
-    "NVFP4InferenceConfig",
-    "NVFP4MMConfig",
+    "NVFP4DynamicActivationNVFP4WeightConfig",
+    "NVFP4WeightOnlyConfig",
 ]
```
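
With the updated `__all__`, both new configs resolve from the package root. A quick import check, as a sketch:

```python
# Both renamed configs are re-exported at the package root (per the
# __all__ change above), so call sites no longer need the full
# inference_workflow module path.
from torchao.prototype.mx_formats import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    NVFP4WeightOnlyConfig,
)
```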
