1616 NVFP4MMConfig ,
1717)
1818from torchao .prototype .mx_formats .nvfp4_tensor import (
19+ NVFP4Tensor ,
1920 QuantizeTensorToNVFP4Kwargs ,
21+ per_tensor_amax_to_scale ,
22+ unpack_uint4 ,
2023)
2124from torchao .quantization .utils import compute_error
2225from torchao .testing .utils import skip_if_rocm
4548 not torch_version_at_least ("2.8.0" ), reason = "torch.compile requires PyTorch 2.8+"
4649)
4750def test_nvfp4_reconstruction (dtype , shape , use_per_tensor_scale ):
48- from torchao .prototype .mx_formats .nvfp4_tensor import (
49- NVFP4Tensor ,
50- per_tensor_amax_to_scale ,
51- )
52-
5351 x = torch .randn (shape , dtype = dtype , device = "cuda" )
5452 if use_per_tensor_scale :
5553 tensor_amax = torch .max (torch .abs (x ))
@@ -115,7 +113,6 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
115113 Test that NVFP4Tensor can be constructed with swizzled scales and
116114 that the _is_swizzled_scales flag is set correctly.
117115 """
118- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
119116
120117 M , K = shape
121118 data = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
@@ -153,7 +150,6 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
153150 Test that slicing works correctly with swizzled scales and maintains
154151 the swizzled state in the output tensor.
155152 """
156- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
157153
158154 # Use larger tensor sizes that align with swizzled requirements
159155 if slice_dim == 0 :
@@ -247,7 +243,6 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
247243 """
248244 Test that slicing raises appropriate errors for misaligned boundaries.
249245 """
250- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
251246
252247 M , K = 256 , 4096
253248 data = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
@@ -268,7 +263,6 @@ def test_nvfp4_swizzled_scales_view_semantics():
268263 """
269264 Test that slicing maintains proper view semantics where possible.
270265 """
271- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
272266
273267 M , K = 256 , 4096
274268 data = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
@@ -295,7 +289,6 @@ def test_nvfp4_swizzled_scales_serialization():
295289 """
296290 Test that tensor flatten/unflatten preserves the swizzled scales state.
297291 """
298- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
299292
300293 M , K = 32 , 64
301294 data = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
@@ -337,7 +330,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
337330 """
338331 Test that the get_scales() method correctly unswizzles scales when needed.
339332 """
340- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
341333
342334 M , K = 32 , 64
343335 data = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
@@ -372,11 +364,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
372364@torch .no_grad ()
373365def test_triton_nvfp4_quantize_equivalence (M , N , use_per_tensor_scale , dtype ):
374366 """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
375- from torchao .prototype .mx_formats .nvfp4_tensor import (
376- NVFP4Tensor ,
377- per_tensor_amax_to_scale ,
378- unpack_uint4 ,
379- )
380367
381368 torch .manual_seed (42 )
382369 x = torch .randn (M , N , dtype = dtype , device = "cuda" )
@@ -462,11 +449,6 @@ def test_nvfp4_matmul_with_amax(
462449 use_triton_kernel : bool ,
463450 shapes : tuple ,
464451):
465- from torchao .prototype .mx_formats .nvfp4_tensor import (
466- NVFP4Tensor ,
467- per_tensor_amax_to_scale ,
468- )
469-
470452 # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
471453 if mm_config == NVFP4MMConfig .DYNAMIC and not is_sm_at_least_100 ():
472454 pytest .skip ("CUDA capability >= 10.0 required for DYNAMIC float4 gemm" )
@@ -530,8 +512,6 @@ def test_nvfp4_matmul_with_amax(
530512 not torch_version_at_least ("2.8.0" ), reason = "NVFP4 requires PyTorch 2.8+"
531513)
532514def test_nvfp4_to_copy ():
533- from torchao .prototype .mx_formats .nvfp4_tensor import NVFP4Tensor
534-
535515 x = NVFP4Tensor .to_nvfp4 (torch .randn ((32 , 128 ))).cuda ()
536516 y = torch .ops .aten ._to_copy (x , dtype = torch .bfloat16 )
537517 assert torch .equal (x .qdata , y .qdata )
0 commit comments