1414
1515from torchao .prototype .mx_formats .inference_workflow import (
1616 MXDynamicActivationMXWeightConfig ,
17- NVFP4InferenceConfig ,
18- NVFP4MMConfig ,
17+ NVFP4DynamicActivationNVFP4WeightConfig ,
18+ NVFP4WeightOnlyConfig ,
1919)
2020from torchao .quantization import quantize_
2121from torchao .quantization .quantize_ .common import KernelPreference
@@ -138,9 +138,7 @@ def test_inference_workflow_mx(
138138)
139139@pytest .mark .parametrize ("bias" , [True , False ])
140140@pytest .mark .parametrize ("compile" , [True , False ])
141- @pytest .mark .parametrize (
142- "mm_config" , [NVFP4MMConfig .DYNAMIC , NVFP4MMConfig .WEIGHT_ONLY ]
143- )
141+ @pytest .mark .parametrize ("quant_type" , ["dynamic" , "weight_only" ])
144142@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
145143@pytest .mark .parametrize ("use_triton_kernel" , [True , False ])
146144@pytest .mark .parametrize ("use_dynamic_per_tensor_scale" , [True , False ])
@@ -164,7 +162,7 @@ def test_inference_workflow_mx(
164162def test_inference_workflow_nvfp4 (
165163 bias : bool ,
166164 compile : bool ,
167- mm_config : NVFP4MMConfig ,
165+ quant_type : str ,
168166 inpt_dtype : torch .dtype ,
169167 use_triton_kernel : bool ,
170168 use_dynamic_per_tensor_scale : bool ,
@@ -177,14 +175,16 @@ def test_inference_workflow_nvfp4(
177175 Tests both the "dynamic" and "weight_only" quant_type modes
178176 """
179177 # "dynamic" quant_type requires SM100+, but "weight_only" works on older GPUs
180- if mm_config == NVFP4MMConfig . DYNAMIC and not is_sm_at_least_100 ():
178+ if quant_type == "dynamic" and not is_sm_at_least_100 ():
181179 pytest .skip ("CUDA capability >= 10.0 required for DYNAMIC float4 gemm" )
182180
183181 if bias and inpt_dtype == torch .float32 :
184182 pytest .xfail ("Bias is not supported when module weight is in fp32" )
185183
186- if mm_config == NVFP4MMConfig .WEIGHT_ONLY and compile :
187- pytest .skip ("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile" )
184+ if quant_type == "weight_only" and compile :
185+ pytest .skip ("TODO: weight_only quant currently errors w/ compile" )
186+ if quant_type == "weight_only" and use_triton_kernel :
187+ pytest .skip ("unsupported configuration" )
188188
189189 if use_inference_mode and (
190190 shapes != (128 , 64 , 256 ) or inpt_dtype != torch .bfloat16 or use_triton_kernel
@@ -200,11 +200,15 @@ def test_inference_workflow_nvfp4(
200200 m = nn .Linear (in_features , out_features , bias = bias , dtype = inpt_dtype , device = "cuda" )
201201 m_mx = copy .deepcopy (m )
202202
203- config = NVFP4InferenceConfig (
204- mm_config = mm_config ,
205- use_triton_kernel = use_triton_kernel ,
206- use_dynamic_per_tensor_scale = use_dynamic_per_tensor_scale ,
207- )
203+ if quant_type == "dynamic" :
204+ config = NVFP4DynamicActivationNVFP4WeightConfig (
205+ use_triton_kernel = use_triton_kernel ,
206+ use_dynamic_per_tensor_scale = use_dynamic_per_tensor_scale ,
207+ )
208+ else :
209+ config = NVFP4WeightOnlyConfig (
210+ use_dynamic_per_tensor_scale = use_dynamic_per_tensor_scale ,
211+ )
208212 quantize_ (m_mx , config = config )
209213
210214 if compile :
@@ -216,7 +220,7 @@ def test_inference_workflow_nvfp4(
216220
217221 y_ref = m (x )
218222
219- if use_triton_kernel and mm_config != NVFP4MMConfig . WEIGHT_ONLY :
223+ if use_triton_kernel and quant_type == "dynamic" :
220224 with cuda_kernel_profiler ("quantize_nvfp4_triton_kernel" ) as result :
221225 y_mx = m_mx (x )
222226 assert result ["found" ], "Expected quantize_nvfp4 kernel to be found"
@@ -229,14 +233,14 @@ def test_inference_workflow_nvfp4(
229233
230234 sqnr = compute_error (y_ref , y_mx )
231235
232- if mm_config == NVFP4MMConfig . WEIGHT_ONLY :
236+ if quant_type == "weight_only" :
233237 SQNR_THRESHOLD = 18.0
234238 else :
235239 SQNR_THRESHOLD = 15.0
236240
237241 assert y_mx .dtype == inpt_dtype , f"Got { y_mx .dtype } for inpt_dtype={ inpt_dtype } "
238242 assert sqnr >= SQNR_THRESHOLD , (
239- f"Got a sqnr of { sqnr } for NVFP4 recipe with bias={ bias } , mm_config= { mm_config } "
243+ f"Got a sqnr of { sqnr } for NVFP4 recipe with bias={ bias } , { quant_type = } "
240244 )
241245
242246
@@ -273,9 +277,7 @@ def test_narrow_similar_to_vllm(self):
273277 reason = "torch.compile requires PyTorch 2.8+" ,
274278 )
275279 def test_nvfp4_quantize_3d_param_similar_to_vllm (self ):
276- config = NVFP4InferenceConfig (
277- mm_config = NVFP4MMConfig .WEIGHT_ONLY ,
278- use_triton_kernel = False ,
280+ config = NVFP4WeightOnlyConfig (
279281 use_dynamic_per_tensor_scale = False ,
280282 )
281283 self ._test_quantize_3d_param_similar_to_vllm (config )
0 commit comments