
Commit c70894d

[wip] make scale shape 2d and match qdata shape in NVFP4Tensor
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 6d5cad3
ghstack-comment-id: 3357503258
Pull-Request: #3108
1 parent 9368b28 commit c70894d

File tree

test/prototype/mx_formats/test_nvfp4_tensor.py
torchao/prototype/mx_formats/nvfp4_tensor.py
torchao/prototype/mx_formats/utils.py

3 files changed: +120 -10 lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py
Lines changed: 56 additions & 0 deletions
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -525,3 +526,58 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k=}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
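For intuition, the shape arithmetic this test encodes can be checked by hand. A minimal sketch, assuming block_size=16 and the same rounding semantics as torchao.prototype.mx_formats.utils.ceil_div; the numbers are the (128 + 16, 64 + 16) case from the parametrization:

def ceil_div(a, b):
    # same rounding-up division as torchao.prototype.mx_formats.utils.ceil_div
    return (a + b - 1) // b

M, K = 128 + 16, 64 + 16  # (144, 80)
block_size = 16

# non-swizzled: one scale element per 1x16 hp tile, no padding
print((M, K // block_size))  # (144, 5)

# swizzled: M pads to a multiple of 128 (32 scale rows per 128-row tile),
# K // block_size pads to a multiple of 4 (16 scale cols per tile)
print((ceil_div(M, 128) * 32, ceil_div(K // block_size, 4) * 16))  # (64, 32)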

torchao/prototype/mx_formats/nvfp4_tensor.py
Lines changed: 46 additions & 10 deletions
@@ -24,7 +24,11 @@
     tensor_size_fp4x2_to_hp,
     tensor_size_hp_to_fp4x2,
 )
-from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+from torchao.prototype.mx_formats.utils import (
+    from_blocked,
+    hp_data_dims_to_swizzled_scale_dims_nvfp4,
+    to_blocked,
+)
 from torchao.quantization.quantize_.common import (
     QuantizeTensorKwargs,
 )
@@ -170,6 +174,9 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
+        assert len(data_hp.shape) == 2, "unsupported"
+        M, K = data_hp.shape[0], data_hp.shape[1]
+
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
             assert data_hp.shape[1] % 16 == 0, (
@@ -181,12 +188,19 @@ def to_nvfp4(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()
 
+        if is_swizzled_scales:
+            scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K)
+        else:
+            # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+            # scale element
+            scale_M, scale_K = M, K // block_size
+        blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+
         return NVFP4Tensor(
             data_lp,
             blockwise_scales,
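The net effect of this hunk is that the scale is always stored 2d, with a shape that is a pure function of the hp dims. A sketch of the rule (scale_dims is a hypothetical standalone version of the branch above, assuming block_size=16):

def scale_dims(M, K, is_swizzled_scales, block_size=16):
    # hypothetical standalone version of the scale-shape branch above
    ceil_div = lambda a, b: (a + b - 1) // b
    if is_swizzled_scales:
        # a 128x64 unpacked qdata tile maps to a swizzled 32x16 scale tile
        return ceil_div(M, 128) * 32, ceil_div(K, 64) * 16
    # a 1x16 unpacked qdata tile maps to 1 scale element
    return M, K // block_size

print(scale_dims(128, 64, is_swizzled_scales=False))  # (128, 4)
print(scale_dims(128, 64, is_swizzled_scales=True))   # (32, 16)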
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
 
         return (
             scale_e4m3.to(self._orig_dtype)
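This branch relies on nvfp4_t (below) now transposing _scale_e4m3 together with qdata, so transposing the scale back recovers the contiguous orientation the rest of the method expects. A minimal sketch with hypothetical shapes:

import torch

scale = torch.randn(128, 4)  # contiguous scale for 128x64 hp data, block_size=16
scale_t = scale.t()          # what _scale_e4m3 looks like after a transpose
# get_hp_scales undoes the transpose before converting to hp scales
assert torch.equal(scale_t.t(), scale)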
@@ -369,6 +383,11 @@ def nvfp4_slice(func, types, args, kwargs):
 
     M, K = x.shape[0], x.shape[1]
 
+    # The scale manipulations below assume a flattened scale. For now, we
+    # flatten the scale, go through the calculations below, and then reshape
+    # it back to the format which matches the shape of `qdata`.
+    # TODO(future PR): update this
+
     if x._is_swizzled_scales:
         scale_rows = M
         scale_cols = K // x._block_size
@@ -407,7 +426,9 @@
                 else None
             )
 
-            sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
+            sliced_scale = aten.slice.Tensor(
+                x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1
+            )
             sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
     elif dim == 1:
@@ -462,7 +483,7 @@
                 row_start = row_block * elements_per_row_block
                 col_start = row_start + start_col_block * elements_per_block
                 col_end = row_start + end_col_block * elements_per_block
-                slices_to_extract.append(x._scale_e4m3[col_start:col_end])
+                slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end])
 
             # Concatenate all the slices
             sliced_scale = torch.cat(slices_to_extract, dim=0)
@@ -515,6 +536,19 @@
 
     sliced_scale = sliced_scale.flatten()
 
+    # reshape at the end
+    sliced_M = sliced_data.shape[0]
+    # multiply by 2 to convert from bytes to num_elements
+    sliced_K = sliced_data.shape[1] * 2
+    if x._is_swizzled_scales:
+        scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K)
+    else:
+        # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+        # scale element
+        scale_M = sliced_M
+        scale_K = sliced_K // x._block_size
+    sliced_scale = sliced_scale.view(scale_M, scale_K)
+
     # Create result tensor
     result = NVFP4Tensor(
         sliced_data,
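The multiply-by-2 above is the fp4x2 packing: qdata stores two fp4 elements per uint8, so its second dim is in bytes. A small sketch with hypothetical numbers:

# hypothetical sliced qdata of shape (64, 32) uint8 bytes
sliced_M, qdata_bytes_per_row = 64, 32
sliced_K = qdata_bytes_per_row * 2  # 64 logical fp4 elements per row
block_size = 16
# non-swizzled scale shape for this slice
print((sliced_M, sliced_K // block_size))  # (64, 4)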
@@ -537,7 +571,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -576,7 +610,9 @@ def _addmm_nvfp4_dispatch(
     The only difference is whether bias is None or not.
     """
     assert a.qdata.is_contiguous()
+    assert a._scale_e4m3.is_contiguous()
     assert b.qdata.t().is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -591,9 +627,9 @@
         a_scale_blocked = to_blocked(a_scale)
 
     if b._is_swizzled_scales:
-        b_scale_blocked = b._scale_e4m3  # Already swizzled
+        b_scale_blocked = b._scale_e4m3.t()  # Already swizzled
     else:
-        b_scale = b._scale_e4m3.view(N, K // b._block_size)
+        b_scale = b._scale_e4m3.t().view(N, K // b._block_size)
         b_scale_blocked = to_blocked(b_scale)
 
     # Merge double quant scales into 1 scale for Scale_In^D
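The new .t() calls follow from the nvfp4_t change: in this dispatch b is the transposed weight (b.qdata.t() is contiguous, per the asserts above), so its 2d scale is now stored transposed as well, and .t() recovers the (N, K // b._block_size) orientation that to_blocked expects. A sketch with hypothetical shapes, non-swizzled case:

import torch

N, K, block_size = 256, 128, 16
# the weight's scale is computed for (N, K) data: shape (N, K // block_size);
# after the weight is transposed into b, b._scale_e4m3 holds its t()
b_scale_stored = torch.randn(N, K // block_size).t()   # shape (8, 256)
b_scale = b_scale_stored.t().view(N, K // block_size)  # back to (256, 8)
assert b_scale.shape == (N, K // block_size)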

torchao/prototype/mx_formats/utils.py
Lines changed: 18 additions & 0 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Tuple
+
 import torch
 from torch.distributed._tensor import DTensor
 
@@ -99,6 +101,22 @@ def from_blocked(
     return padded[:original_rows, :original_cols]
 
 
+def hp_data_dims_to_swizzled_scale_dims_nvfp4(
+    hp_data_M,
+    hp_data_K,
+) -> Tuple[int, int]:
+    """
+    Given the `M` and `K` dimensions of a high precision contiguous tensor,
+    returns a 2d tuple of the dims of the swizzled nvfp4 scale corresponding to
+    that tensor.
+    """
+    # a 128x64 unpacked or 128x32 packed qdata tile corresponds
+    # to a swizzled 32x16 scale tile
+    scale_M = ceil_div(hp_data_M, 128) * 32
+    scale_K = ceil_div(hp_data_K, 64) * 16
+    return scale_M, scale_K
+
+
 def _to_blocked_single(scales: Tensor) -> Tensor:
     """Assume that we have a 128x4 block of scales in K Major order
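A usage sketch for the new helper, reusing the padded case from test_scale_shape_matches_qdata above (144 rows pad to two 128-row tiles, 80 cols pad to two 64-col tiles):

from torchao.prototype.mx_formats.utils import (
    hp_data_dims_to_swizzled_scale_dims_nvfp4,
)

# 144 -> ceil(144 / 128) * 32 = 64; 80 -> ceil(80 / 64) * 16 = 32
print(hp_data_dims_to_swizzled_scale_dims_nvfp4(144, 80))  # (64, 32)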
