2020 (8 , 513 , 64 ), # Non-divisible (native only)
2121 ])
2222@pytest .mark .parametrize ("seed" , [42 ])
23+ @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
2324@torch .inference_mode ()
2425def test_quantfp8_group_functionality (batch_size : int , hidden_dim : int ,
25- group_size : int , seed : int ) -> None :
26+ group_size : int , seed : int ,
27+ use_ue8m0 : bool ) -> None :
2628 """Test QuantFP8 group quantization with various configurations.
2729
2830 Tests both CUDA and native implementations, column-major scales,
@@ -38,7 +40,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
3840 group_shape = GroupShape (1 , group_size )
3941 quant_op = QuantFP8 (static = False ,
4042 group_shape = group_shape ,
41- column_major_scales = False )
43+ column_major_scales = False ,
44+ use_ue8m0 = use_ue8m0 )
4245
4346 # 1. Test native implementation (always available)
4447 x_quant_native , scales_native = quant_op .forward_native (x .clone ())
@@ -48,9 +51,15 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
4851 # 2. Test column-major scales configuration
4952 quant_op_col = QuantFP8 (static = False ,
5053 group_shape = group_shape ,
51- column_major_scales = True )
54+ column_major_scales = True ,
55+ use_ue8m0 = use_ue8m0 )
5256 _ , scales_col = quant_op_col .forward_native (x .clone ())
53- assert scales_col .shape == (expected_num_groups , batch_size )
57+ assert scales_col .shape == (batch_size , expected_num_groups )
58+ assert scales_col .stride (0 ) == 1
59+ assert scales_col .stride (1 ) == batch_size
60+
61+ # Test column-major scales consistency
62+ assert torch .allclose (scales_col , scales_native , rtol = 1e-9 , atol = 1e-8 )
5463
5564 # 3. Test CUDA implementation (only for divisible dimensions)
5665 if is_divisible :
@@ -68,8 +77,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
6877
6978
7079@pytest .mark .parametrize ("seed" , [42 ])
80+ @pytest .mark .parametrize ("use_ue8m0" , [True , False ])
7181@torch .inference_mode ()
72- def test_quantfp8_group_multidimensional (seed : int ) -> None :
82+ def test_quantfp8_group_multidimensional (seed : int , use_ue8m0 : bool ) -> None :
7383 current_platform .seed_everything (seed )
7484
7585 group_size = 64
@@ -82,7 +92,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
8292 group_shape = GroupShape (1 , group_size )
8393 quant_op = QuantFP8 (static = False ,
8494 group_shape = group_shape ,
85- column_major_scales = False )
95+ column_major_scales = False ,
96+ use_ue8m0 = use_ue8m0 )
8697
8798 x_quant , scales = quant_op .forward_native (x_3d .clone ())
8899 assert x_quant .shape == x_3d .shape
@@ -91,7 +102,8 @@ def test_quantfp8_group_multidimensional(seed: int) -> None:
91102 # Test column_major_scales with multi-dim
92103 quant_op_col = QuantFP8 (static = False ,
93104 group_shape = group_shape ,
94- column_major_scales = True )
105+ column_major_scales = True ,
106+ use_ue8m0 = use_ue8m0 )
95107 _ , scales_col = quant_op_col .forward_native (x_3d .clone ())
96108 assert scales_col .shape == (batch1 , hidden_dim // group_size , batch2 )
97109
0 commit comments