Commit f230eb6

Add fp4 quantization swizzling tests (#1157)
1 parent ac78bc3 commit f230eb6

File tree

1 file changed: +106 -4 lines changed


tests/test_fp4_quantize.py

Lines changed: 106 additions & 4 deletions
@@ -1,11 +1,14 @@
+import functools
+
 import pytest
 import torch
 
 from flashinfer import fp4_quantize
 from flashinfer.utils import is_sm100a_supported
 
 DTYPES = [torch.float16, torch.bfloat16]
-SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
+# The batch dimension doesn't need to be a multiple of 128
+SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
 SEEDS = [42]
 CUDA_DEVICES = ["cuda:0"]
 
@@ -42,6 +45,67 @@
 BLOCK_SIZE = 16
 
 
+def swizzle_sf(
+    unswizzled_sf: torch.Tensor,
+    original_row: int,
+    original_col: int,
+    scaling_vector_size: int = 16,
+) -> torch.Tensor:
+    """
+    Inverse of `unswizzle_sf`. Converts an unswizzled tensor back to swizzled form.
+
+    Args:
+        unswizzled_sf: Tensor of shape [row, col // scaling_vector_size].
+        original_row: Original row dimension (e.g., 120).
+        original_col: Original column dimension (e.g., 64).
+        scaling_vector_size: Number of elements covered by each scale factor (default 16).
+
+    Returns:
+        Swizzled tensor of shape [padded_row, padded_col // scaling_vector_size].
+    """
+    unswizzled_sf = unswizzled_sf.contiguous()
+    factor = scaling_vector_size * 4
+    padded_row = ((original_row + 128 - 1) // 128) * 128  # Next multiple of 128
+    padded_col = ((original_col + factor - 1) // factor) * factor  # Next multiple of factor (64 by default)
+
+    # Pad the input tensor to [padded_row, padded_col // scaling_vector_size]
+    pad_rows = padded_row - original_row
+    pad_cols = (padded_col - original_col) // scaling_vector_size
+    padded_sf = torch.nn.functional.pad(
+        unswizzled_sf,
+        (0, pad_cols, 0, pad_rows),
+        mode="constant",
+        value=0,
+    ).contiguous()
+
+    # Reshape and transpose to reverse unswizzle_sf
+    num_m_tiles = padded_row // 128
+    num_k_tiles = padded_col // factor
+    sf_reshaped = padded_sf.view(num_m_tiles, 4, 32, num_k_tiles, 4)  # Reverse reshape
+    sf_swizzled = sf_reshaped.transpose(
+        1, 3
+    )  # Reverse transpose to [num_m_tiles, num_k_tiles, 32, 4, 4]
+    sf_swizzled = sf_swizzled.reshape(
+        padded_row, padded_col // scaling_vector_size
+    )  # Flatten to [padded_row, padded_col // scaling_vector_size]
+
+    return sf_swizzled.contiguous()
+
+
+def unswizzle_sf(
+    sf: torch.Tensor, row: int, col: int, scaling_vector_size: int = 16
+) -> torch.Tensor:
+    factor = scaling_vector_size * 4
+    num_m_tiles = (row + 128 - 1) // 128
+    num_k_tiles = (col + factor - 1) // factor
+    # SF layout: [num_m_tiles, num_k_tiles, 32 (m tile, column-major), 4 (m tile, column-major), 4 (k tile)]
+    sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
+    sf_unswizzle = sf_reshaped.transpose(1, 3)
+    sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
+    sf_unswizzle_sliced = sf_unswizzle[:row, : (col // scaling_vector_size)]
+    return sf_unswizzle_sliced.contiguous()
+
+
 def cast_from_fp4(x, m, n):
     # The fp4 values are packed in uint8 as [v_1st | v_2nd]
     v_2nd = x & 0xF
@@ -107,23 +171,24 @@ def recover_swizzled_scales(scale, m, n):
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_quantize_to_fp4(
+def test_fp4_quantization(
     dtype: torch.dtype,
     shape: tuple[int, int],
     seed: int,
     device: str,
 ) -> None:
-    if not is_sm100a_supported(torch.device("cuda")):
+    if not is_sm100a_supported(torch.device(device)):
         pytest.skip("Nvfp4 Requires compute capability of 10 or above")
     torch.set_default_device(device)
+    torch.manual_seed(seed)
     m, n = shape
     x = torch.randn((m, n), dtype=dtype)
     tensor_amax = torch.abs(x).max().to(torch.float32)
     global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
     out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
 
     out, out_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False)
-    assert (n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible")
+    assert n % BLOCK_SIZE == 0, f"cols must be divisible by {BLOCK_SIZE}"
     scale_ans = recover_swizzled_scales(
         out_scale.reshape(-1, n // BLOCK_SIZE).view(torch.float8_e4m3fn), m, n
     )
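
An aside on the assert fix in this hunk: the deleted form wraps the condition and message in parentheses, which builds a two-element tuple, and a non-empty tuple is always truthy, so that assertion could never fail. A standalone illustration, with values invented for the example:

    # A two-element tuple is always truthy, so this assert never fires,
    # even though the condition is false (CPython also emits a SyntaxWarning):
    assert (1 == 2, "unreachable message")

    # The corrected form actually evaluates the condition:
    n, BLOCK_SIZE = 64, 16
    assert n % BLOCK_SIZE == 0, f"cols must be divisible by {BLOCK_SIZE}"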
@@ -132,5 +197,42 @@ def test_quantize_to_fp4(
     torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1)
 
 
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_scale_swizzling(
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+    seed: int,
+    device: str,
+) -> None:
+    if not is_sm100a_supported(torch.device(device)):
+        pytest.skip("Nvfp4 Requires compute capability of 10 or above")
+    torch.set_default_device(device)
+    torch.manual_seed(seed)
+    m, n = shape
+    x = torch.randn((m, n), dtype=dtype)
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+
+    _, unswizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, False)
+    _, swizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, True)
+    assert n % BLOCK_SIZE == 0, f"cols must be divisible by {BLOCK_SIZE}"
+    recovered_unswizzled_scale = unswizzle_sf(
+        swizzle_sf(unswizzled_scale, m, n),
+        m,
+        n,
+    )
+
+    # Because of padding, we don't expect swizzle_sf(unswizzled_scale)
+    # == swizzled_scale directly; compare after unswizzling instead.
+    ref_unswizzled_scale = unswizzle_sf(swizzled_scale, m, n)
+    assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
+    assert_equal(recovered_unswizzled_scale, unswizzled_scale)
+    assert_equal(ref_unswizzled_scale, unswizzled_scale)
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
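
For readers new to the swizzled scale-factor layout these helpers handle, the mapping can be stated directly: the swizzled buffer stores the scale for unswizzled position (r, c) at index [r // 128, c // 4, r % 32, (r % 128) // 32, c % 4] of the 5-D view used in unswizzle_sf. The minimal CPU-only sketch below checks this; the unswizzle_sf restatement is copied from the diff, and the synthetic arange tensor merely stands in for real scale factors from fp4_quantize:

    import torch

    def unswizzle_sf(sf, row, col, scaling_vector_size=16):
        # Restated from the diff above, for illustration only
        factor = scaling_vector_size * 4
        num_m_tiles = (row + 128 - 1) // 128
        num_k_tiles = (col + factor - 1) // factor
        sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
        out = sf_reshaped.transpose(1, 3).reshape(num_m_tiles * 128, num_k_tiles * 4)
        return out[:row, : (col // scaling_vector_size)].contiguous()

    row, col = 256, 128  # two 128-row tiles, two 64-element column tiles
    sf = torch.arange(row * (col // 16)).view(row, col // 16)  # synthetic scales
    un = unswizzle_sf(sf, row, col)

    # Check the index mapping at an arbitrary position
    r, c = 130, 5
    tile_view = sf.view(2, 2, 32, 4, 4)
    assert un[r, c] == tile_view[r // 128, c // 4, r % 32, (r % 128) // 32, c % 4]

In other words, rows b, b + 32, b + 64, b + 96 of a 128-row tile end up adjacent in the swizzled buffer, four scale columns at a time.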

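The padding caveat in test_scale_swizzling can also be made concrete. For a shape like (120, 64) from SHAPES, swizzle_sf pads the 120 rows up to a full 128-row tile, so the swizzled buffer is larger than the unswizzled one and a direct elementwise comparison isn't even shape-compatible. A rough sketch, assuming swizzle_sf and unswizzle_sf as defined in the diff above:

    import torch

    m, n, vec = 120, 64, 16
    unswizzled = torch.randint(0, 256, (m, n // vec), dtype=torch.uint8)

    swizzled = swizzle_sf(unswizzled, m, n)
    print(unswizzled.shape)  # torch.Size([120, 4])
    print(swizzled.shape)    # torch.Size([128, 4]) -- padded to a full row tile

    # Comparing swizzled buffers directly would mix in the zero padding;
    # unswizzling first recovers exactly the original 120 x 4 scales:
    assert torch.equal(unswizzle_sf(swizzled, m, n), unswizzled)

This is why the test compares both paths only after unswizzling.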