import functools

import pytest
import torch

from flashinfer import fp4_quantize
from flashinfer.utils import is_sm100a_supported

DTYPES = [torch.float16, torch.bfloat16]
# The batch dimension doesn't need to be a multiple of 128
SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
SEEDS = [42]
CUDA_DEVICES = ["cuda:0"]

# ... unchanged lines omitted ...

BLOCK_SIZE = 16


def swizzle_sf(
    unswizzled_sf: torch.Tensor,
    original_row: int,
    original_col: int,
    scaling_vector_size: int = 16,
) -> torch.Tensor:
54+ """
55+ Inverse of `unswizzle_sf`. Converts an unswizzled tensor back to swizzled form.
56+
57+ Args:
58+ unswizzled_sf: Tensor of shape [row, col // scaling_vector_size].
59+ original_row: Original row dimension (e.g., 120).
60+ original_col: Original column dimension (e.g., 64).
61+ scaling_vector_size: Scaling factor (default 16).
62+
63+ Returns:
64+ Swizzled tensor of shape [padded_row, padded_col // scaling_vector_size].
65+ """
    unswizzled_sf = unswizzled_sf.contiguous()
    factor = scaling_vector_size * 4
    padded_row = ((original_row + 128 - 1) // 128) * 128  # Next multiple of 128
    padded_col = ((original_col + factor - 1) // factor) * factor  # Next multiple of factor (64 by default)

    # Pad the input tensor to [padded_row, padded_col // scaling_vector_size]
    pad_rows = padded_row - original_row
    pad_cols = (padded_col - original_col) // scaling_vector_size
    padded_sf = torch.nn.functional.pad(
        unswizzled_sf,
        (0, pad_cols, 0, pad_rows),
        mode="constant",
        value=0,
    ).contiguous()

    # Reshape and transpose to reverse unswizzle_sf
    num_m_tiles = padded_row // 128
    num_k_tiles = padded_col // factor
    sf_reshaped = padded_sf.view(num_m_tiles, 4, 32, num_k_tiles, 4)  # Reverse the unswizzle reshape
    sf_swizzled = sf_reshaped.transpose(
        1, 3
    )  # Reverse the transpose: [num_m_tiles, num_k_tiles, 32, 4, 4]
    sf_swizzled = sf_swizzled.reshape(
        padded_row, padded_col // scaling_vector_size
    )  # Flatten to [padded_row, padded_col // scaling_vector_size]

    return sf_swizzled.contiguous()


def unswizzle_sf(
    sf: torch.Tensor, row: int, col: int, scaling_vector_size: int = 16
) -> torch.Tensor:
    """Convert a swizzled scale tensor back to row-major [row, col // scaling_vector_size]."""
    factor = scaling_vector_size * 4
    num_m_tiles = (row + 128 - 1) // 128
    num_k_tiles = (col + factor - 1) // factor
    # SF layout: [num_m_tiles, num_k_tiles, 32 (m_tile, column major), 4 (m_tile, column major), 4 (k_tile)]
    sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
    sf_unswizzle = sf_reshaped.transpose(1, 3)
    sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
    sf_unswizzle_sliced = sf_unswizzle[:row, : (col // scaling_vector_size)]
    return sf_unswizzle_sliced.contiguous()
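
# Shape note (informal): for an (m, n) input with n a multiple of 16, the
# unswizzled scale tensor is (m, n // 16). swizzle_sf pads m up to the next
# multiple of 128 (and n // 16 up to the next multiple of 4) before reordering;
# unswizzle_sf reverses the reordering and slices the padding back off.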


def cast_from_fp4(x, m, n):
    # The fp4 values are packed in uint8 as [v_1st | v_2nd]
    v_2nd = x & 0xF
    # ... unchanged lines omitted: the rest of cast_from_fp4 and the
    # ref_nvfp4_quant / recover_swizzled_scales reference helpers used below ...


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp4_quantization(
    dtype: torch.dtype,
    shape: tuple[int, int],
    seed: int,
    device: str,
) -> None:
    if not is_sm100a_supported(torch.device(device)):
        pytest.skip("NVFP4 requires compute capability 10.0 or above")
    torch.set_default_device(device)
    torch.manual_seed(seed)
    m, n = shape
    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    # Global scale maps the tensor's amax into the FP8 (E4M3) scale x FP4 (E2M1) value range
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)

    out, out_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False)
    assert n % BLOCK_SIZE == 0, f"n must be divisible by {BLOCK_SIZE}"
    scale_ans = recover_swizzled_scales(
        out_scale.reshape(-1, n // BLOCK_SIZE).view(torch.float8_e4m3fn), m, n
    )
    # ... unchanged lines omitted ...
    torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_scale_swizzling(
    dtype: torch.dtype,
    shape: tuple[int, int],
    seed: int,
    device: str,
) -> None:
    if not is_sm100a_supported(torch.device(device)):
        pytest.skip("NVFP4 requires compute capability 10.0 or above")
    torch.set_default_device(device)
    torch.manual_seed(seed)
    m, n = shape
    x = torch.randn((m, n), dtype=dtype)
    tensor_amax = torch.abs(x).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax

    _, unswizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, False)
    _, swizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, True)
    assert n % BLOCK_SIZE == 0, f"n must be divisible by {BLOCK_SIZE}"
    recovered_unswizzled_scale = unswizzle_sf(
        swizzle_sf(unswizzled_scale, m, n),
        m,
        n,
    )

    # Because of padding we do not expect
    #   swizzle_sf(unswizzled_scale, m, n) == swizzled_scale
    # element-wise, so compare in the unswizzled domain instead.
    ref_unswizzled_scale = unswizzle_sf(swizzled_scale, m, n)
    assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
    assert_equal(recovered_unswizzled_scale, unswizzled_scale)
    assert_equal(ref_unswizzled_scale, unswizzled_scale)
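

# A minimal, CPU-only round-trip sanity check of the two swizzling helpers
# above (an illustrative sketch; no quantization kernel or GPU required). The
# (120, 64) shape is an arbitrary case whose row count is not a multiple of 128.
def test_swizzle_roundtrip_cpu() -> None:
    row, col = 120, 64
    sf = torch.randint(0, 256, (row, col // BLOCK_SIZE), dtype=torch.uint8, device="cpu")
    recovered = unswizzle_sf(swizzle_sf(sf, row, col), row, col)
    assert torch.equal(recovered, sf)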


if __name__ == "__main__":
    pytest.main([__file__, "-v"])