
Commit 16e7da8

[wip] make nvfp4 scale shape match qdata
Summary: does not work yet, need to fix tests. We'll need this to stitch MoE weights in vLLM.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: a5ec3b1
ghstack-comment-id: 3348895157
Pull-Request: #3094
1 parent 9368b28 commit 16e7da8

3 files changed: +207, -36 lines changed

3 files changed

+207
-36
lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 114 additions & 10 deletions
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -127,18 +128,18 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     "slice_dim,slice_spec",
     [
         # Row slicing - must align with 128-row boundaries
-        pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
-        pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
+        # pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
+        # pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
         # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
         pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
-        pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
-        pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
+        # pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
+        # pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
         # Test tensor parallelism patterns (half splits)
-        pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
-        pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
+        # pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
+        # pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
         # Test quarter splits
-        pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
-        pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
+        # pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
+        # pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
     ],
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -157,21 +158,54 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
         M, K = 256, 4096
     else:
         # For column slicing, need multiples of 64 columns for alignment
-        M, K = 128, 4096
+        # M, K = 128, 4096
+        M, K = 128, 64 * 2

     data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+    # tensor.to_dtype(torch.bfloat16)
     assert tensor._is_swizzled_scales == True

+    print(
+        "before",
+        tensor.shape,
+        tensor.qdata.shape,
+        tensor._scale_e4m3.shape,
+    )
+    # print(tensor._scale_e4m3[0:128, 0:4])
     if slice_dim == 0:
         sliced_tensor = tensor[slice_spec, :]
     else:
         sliced_tensor = tensor[:, slice_spec]
+    print(
+        "after",
+        sliced_tensor.shape,
+        sliced_tensor.qdata.shape,
+        sliced_tensor._scale_e4m3.shape,
+    )
+    # print(sliced_tensor._scale_e4m3[0:128, 0:4])
+    # print(sliced_tensor.qdata.float() - tensor.qdata[0:128, 0:32].float())
+    # print(sliced_tensor._scale_e4m3.float() - tensor._scale_e4m3[0:128, 0:4].float())

     # Verify sliced tensor maintains swizzled state
     assert sliced_tensor._is_swizzled_scales == True

+    # this matches sliced_reconstructed, but not original_reconstructed[:, slice_spec]
+    if False:
+        sliced_manually = NVFP4Tensor(
+            tensor.qdata[:, 0:32],
+            tensor._scale_e4m3[:, 0:4].contiguous(),
+            tensor._block_size,
+            tensor._orig_dtype,
+            tensor._per_tensor_scale,
+            tensor._act_per_tensor_scale,
+            tensor._is_swizzled_scales,
+            tensor.use_triton_kernel,
+            tensor.act_quant_kwargs,
+        )
+        import pdb; pdb.set_trace()
+
     # Verify sliced tensor can be dequantized
     sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)

@@ -181,6 +215,11 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
         expected = original_reconstructed[slice_spec, :]
     else:
         expected = original_reconstructed[:, slice_spec]
+    print('e', expected)
+    print('s', sliced_reconstructed)
+    print('e - s', expected - sliced_reconstructed)
+    print(1, expected.abs().sum())
+    print(2, sliced_reconstructed.abs().sum())

     torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)

@@ -421,7 +460,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 @pytest.mark.parametrize("compile", [False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
-@pytest.mark.parametrize("use_triton_kernel", [True, False])
+# @pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize("use_triton_kernel", [False])
 @pytest.mark.parametrize(
     "shapes",
     [
@@ -525,3 +565,67 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+# @pytest.mark.parametrize("transpose", [True])
+# @pytest.mark.parametrize("transpose", [False])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+# @pytest.mark.parametrize("use_triton_kernel", [False])
+# @pytest.mark.parametrize("use_triton_kernel", [True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+# @pytest.mark.parametrize("is_swizzled_scales", [True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
+    # to test the padding logic
+    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]

+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {x.shape} and {x._scale_e4m3.shape}"
+    )
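As a quick check of the padding arithmetic asserted by the new test_scale_shape_matches_qdata above, here is a small worked example (a sketch only; it reuses the same ceil_div helper and picks one of the non-tile-aligned shapes from the parametrization):

from torchao.prototype.mx_formats.utils import ceil_div

block_size = 16
M, K = 128 + 16, 64 + 16  # i.e. (144, 80), not an exact multiple of a (128, 64) data tile

# swizzled scales: each (128, 64) unpacked data tile maps to a (32, 16) scale tile,
# so both dims of the scale tensor are padded up to whole tiles
expected_padded_m = ceil_div(M, 128) * 32              # ceil_div(144, 128) * 32 = 64
expected_padded_k = ceil_div(K // block_size, 4) * 16  # ceil_div(80 // 16, 4) * 16 = 32

# non-swizzled scales: one e4m3 scale per (1, block_size) block, no padding
expected_m, expected_k = M, K // block_size            # (144, 5)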

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 90 additions & 24 deletions
@@ -176,16 +176,31 @@ def to_nvfp4(
                     f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
                 )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+
+            # TODO(before land): share code for scale shape manipulation in the two
+            # if branches
+            scale_M = ceil_div(data_hp.shape[0], 128) * 32
+            scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
+            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
                 M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
-                blockwise_scales = to_blocked(
-                    blockwise_scales.view(scale_shape)
-                ).flatten()
+                # print(1, blockwise_scales.shape)
+                blockwise_scales = blockwise_scales.view(scale_shape)
+                # print(2, blockwise_scales.shape, blockwise_scales)
+                blockwise_scales = to_blocked(blockwise_scales)
+                print(3, blockwise_scales.shape, blockwise_scales)
+
+                # match shape of data_hp
+                # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
+                scale_M = ceil_div(data_hp.shape[0], 128) * 32
+                scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
+                blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+                print(4, blockwise_scales.shape, blockwise_scales)

         return NVFP4Tensor(
             data_lp,
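The "share code for scale shape manipulation in the two if branches" TODO above could plausibly be handled by a small shared helper; the swizzled_scale_shape function below is a hypothetical sketch (not part of this commit), using the same formulas as both branches:

from torchao.prototype.mx_formats.utils import ceil_div

def swizzled_scale_shape(data_shape, block_size=16):
    # hypothetical helper: 2D shape of the swizzled e4m3 scale tensor for a
    # high-precision (M, K) data shape, padded so each (128, 4 * block_size)
    # unpacked data tile maps to a (32, 16) scale tile
    M, K = data_shape
    scale_M = ceil_div(M, 128) * 32
    scale_K = ceil_div(K // block_size, 4) * 16
    return scale_M, scale_K

# usage in either branch of to_nvfp4 (illustrative only):
# blockwise_scales = blockwise_scales.view(*swizzled_scale_shape(data_hp.shape))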
@@ -212,6 +227,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
             torch.Tensor: Dequantized tensor in the target dtype
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print('is_transposed', is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
         else:
@@ -220,8 +236,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)

+        # next: debug scale shape here
         data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        scales = self.get_hp_scales()
+        scales_tmp = scales.reshape(32, -1)
+        print('scales', scales_tmp.shape, scales_tmp[0:8])
+        scale_e4m3_reshaped = scales.view(M, K // self._block_size, 1)
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
         result = data_scaled.view(M, K).to(target_dtype)

@@ -237,15 +257,17 @@ def get_hp_scales(self) -> torch.Tensor:
             torch.Tensor: Scales of the NVFP4Tensor
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        print("is_transposed", is_transposed)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3

         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            # import pdb; pdb.set_trace()
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

         return (
             scale_e4m3.to(self._orig_dtype)
@@ -366,6 +388,7 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")

     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"

     M, K = x.shape[0], x.shape[1]

@@ -376,6 +399,22 @@ def nvfp4_slice(func, types, args, kwargs):
         n_col_blocks = ceil_div(scale_cols, 4)
         elements_per_block = 32 * 16  # 512 elements

+        #
+        # See https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+        # for qdata vs scale layout. Here is a summary specific to nvfp4:
+        #
+        # 1. qdata tile shape is (128, 32) packed fp4, which is (128, 64) unpacked fp4
+        # 2. scale tile shape is (32, 16)
+        # 3. correspondence of qdata vs scale tiles is as follows, in a 2 by 2 tile example
+        #
+        # | tile_idx | qdata_rows | qdata_cols | scale_rows | scale_cols |
+        # ----------------------------------------------------------------
+        # |    0     |   0:127    |    0:31    |    0:31    |    0:15    |
+        # |    1     |  128:255   |    0:31    |   32:63    |    0:15    |
+        # |    2     |   0:127    |   32:63    |    0:31    |   16:31    |
+        # |    3     |  128:255   |   32:63    |   32:63    |   16:31    |
+        #
+
         if dim == 0:
             # Row slicing
             # Handle sys.maxsize (default slice end)
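For reference, the tile correspondence table in the new comment block above can be read as a simple coordinate mapping; the helper below is illustrative only (the function name is made up and is not part of the commit):

def data_coord_to_scale_tile(row, unpacked_col):
    # illustrative only: which (32, 16) scale tile holds the block scale for a
    # given element of the unpacked (M, K) fp4 data, per the table above
    # (each data tile is 128 rows x 64 unpacked cols, i.e. 32 packed cols)
    return row // 128, unpacked_col // 64

# example: unpacked element (row=130, col=70) sits in data tile (1, 1),
# i.e. packed qdata rows 128:255, cols 32:63 -> scale rows 32:63, cols 16:31
assert data_coord_to_scale_tile(130, 70) == (1, 1)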
@@ -407,7 +446,9 @@ def nvfp4_slice(func, types, args, kwargs):
                 else None
             )

-            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+            # TODO(before land): this is wrong, it works but need to express in terms of
+            # properly laid out scale as in the comment block above
+            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
             sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)

         elif dim == 1:
@@ -452,20 +493,43 @@ def nvfp4_slice(func, types, args, kwargs):
                 # Full width - no slicing needed
                 sliced_scale = x._scale_e4m3
             else:
-                # Extract specific column blocks from each row block
-                # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
-                elements_per_row_block = n_col_blocks * elements_per_block
-
-                # Build list of slices to extract
-                slices_to_extract = []
-                for row_block in range(n_row_blocks):
-                    row_start = row_block * elements_per_row_block
-                    col_start = row_start + start_col_block * elements_per_block
-                    col_end = row_start + end_col_block * elements_per_block
-                    slices_to_extract.append(x._scale_e4m3[col_start:col_end])
-
-                # Concatenate all the slices
-                sliced_scale = torch.cat(slices_to_extract, dim=0)
+
+                scale = x._scale_e4m3
+
+                # reshape to expected shape
+                # TODO(before land): do this when swizzling so we don't have to do it here
+
+                # TODO(before land): comment the mul by 2 here
+                scale_rows = n_row_blocks * 16 * 2
+                scale_cols = n_col_blocks * 16
+                scale = scale.view(scale_rows, scale_cols)
+
+                # convert from hp_tensor row to scale row
+                start_scale_col = 0 if start is None else (start // 128 * 16)
+                end_scale_col = scale_cols if end is None or end >= K else (end // 16 * 4)
+                # import pdb; pdb.set_trace()
+
+                if False:
+                    # Extract specific column blocks from each row block
+                    # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
+                    elements_per_row_block = n_col_blocks * elements_per_block
+
+                    # Build list of slices to extract
+                    slices_to_extract = []
+                    for row_block in range(n_row_blocks):
+                        row_start = row_block * elements_per_row_block
+                        col_start = row_start + start_col_block * elements_per_block
+                        col_end = row_start + end_col_block * elements_per_block
+                        slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+
+                    # Concatenate all the slices
+                    sliced_scale = torch.cat(slices_to_extract, dim=0)
+                # import pdb; pdb.set_trace()
+                sliced_scale = aten.slice.Tensor(
+                    # x._scale_e4m3, dim, start_scale_col, end_scale_col, step
+                    scale, dim, start_scale_col, end_scale_col, step
+                ).contiguous()
+                # import pdb; pdb.set_trace()

             # Slice the data tensor
             packed_start = None if start is None else start // 2
@@ -537,7 +601,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -577,6 +641,8 @@ def _addmm_nvfp4_dispatch(
     """
     assert a.qdata.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"

@@ -615,7 +681,7 @@ def _addmm_nvfp4_dispatch(
         a.qdata.view(torch.float4_e2m1fn_x2),
         b.qdata.view(torch.float4_e2m1fn_x2),
         a_scale_blocked.view(torch.float8_e4m3fn),
-        b_scale_blocked.view(torch.float8_e4m3fn),
+        b_scale_blocked.t().view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
         out_dtype=a._orig_dtype,
         # scale_result=scale_result,  # Not supported yet

0 commit comments