     (8, 513, 64),  # Non-divisible (native only)
 ])
 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
 def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
-                                      group_size: int, seed: int,
-                                      use_ue8m0: bool) -> None:
+                                      group_size: int, seed: int) -> None:
     """Test QuantFP8 group quantization with various configurations.

     Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     # 1. Test native implementation (always available)
     x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
     # 2. Test column-major scales configuration
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x.clone())
-    assert scales_col.shape == (batch_size, expected_num_groups)
-    assert scales_col.stride(0) == 1
-    assert scales_col.stride(1) == batch_size
-
-    # Test column-major scales consistency
-    assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
+    assert scales_col.shape == (expected_num_groups, batch_size)

     # 3. Test CUDA implementation (only for divisible dimensions)
     if is_divisible:
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,


 @pytest.mark.parametrize("seed", [42])
-@pytest.mark.parametrize("use_ue8m0", [True, False])
 @torch.inference_mode()
-def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
+def test_quantfp8_group_multidimensional(seed: int) -> None:
     current_platform.seed_everything(seed)

     group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     group_shape = GroupShape(1, group_size)
     quant_op = QuantFP8(static=False,
                         group_shape=group_shape,
-                        column_major_scales=False,
-                        use_ue8m0=use_ue8m0)
+                        column_major_scales=False)

     x_quant, scales = quant_op.forward_native(x_3d.clone())
     assert x_quant.shape == x_3d.shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
     # Test column_major_scales with multi-dim
     quant_op_col = QuantFP8(static=False,
                             group_shape=group_shape,
-                            column_major_scales=True,
-                            use_ue8m0=use_ue8m0)
+                            column_major_scales=True)
     _, scales_col = quant_op_col.forward_native(x_3d.clone())
     assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

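For reference, a minimal standalone sketch of the layout distinction behind the updated assertion, using plain PyTorch only (the tensor names here are illustrative and not part of the test): the removed checks accepted scales with the logical shape (batch_size, num_groups) laid out with column-major strides, while the new check expects a physically transposed (num_groups, batch_size) tensor.

import torch

batch_size, num_groups = 8, 4
# Row-major per-group scales, one scale per (row, group): shape (batch, groups).
scales = torch.rand(batch_size, num_groups, dtype=torch.float32)

# Old assertions: same logical shape, but column-major memory layout
# (stride 1 along the batch dimension, stride batch_size along groups).
col_major_view = scales.t().contiguous().t()
assert col_major_view.shape == (batch_size, num_groups)
assert col_major_view.stride() == (1, batch_size)

# New assertion in this diff: the column-major path yields a physically
# transposed tensor of shape (num_groups, batch_size).
transposed = scales.t().contiguous()
assert transposed.shape == (num_groups, batch_size)

Both forms store the same values; only the shape/stride convention the test asserts against differs.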