2424 tensor_size_fp4x2_to_hp ,
2525 tensor_size_hp_to_fp4x2 ,
2626)
27- from torchao .prototype .mx_formats .utils import from_blocked , to_blocked
27+ from torchao .prototype .mx_formats .utils import (
28+ from_blocked ,
29+ hp_data_dims_to_swizzled_scale_dims_nvfp4 ,
30+ to_blocked ,
31+ )
2832from torchao .quantization .quantize_ .common import (
2933 QuantizeTensorKwargs ,
3034)
@@ -170,6 +174,9 @@ def to_nvfp4(
170174 Returns:
171175 NVFP4Tensor: Quantized tensor in NVFP4 format
172176 """
177+ assert len (data_hp .shape ) == 2 , "unsupported"
178+ M , K = data_hp .shape [0 ], data_hp .shape [1 ]
179+
173180 if use_triton_kernel :
174181 assert is_swizzled_scales , "Triton kernel only supports swizzled scales"
175182 assert data_hp .shape [1 ] % 16 == 0 , (
@@ -181,12 +188,19 @@ def to_nvfp4(
181188 data_hp , block_size , per_tensor_scale
182189 )
183190 if is_swizzled_scales :
184- M , K = data_hp .shape [0 ], data_hp .shape [1 ]
185191 scale_shape = (M , K // block_size )
186192 blockwise_scales = to_blocked (
187193 blockwise_scales .view (scale_shape )
188194 ).flatten ()
189195
196+ if is_swizzled_scales :
197+ scale_M , scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4 (M , K )
198+ else :
199+ # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
200+ # scale element
201+ scale_M , scale_K = M , K // block_size
202+ blockwise_scales = blockwise_scales .view (scale_M , scale_K )
203+
190204 return NVFP4Tensor (
191205 data_lp ,
192206 blockwise_scales ,
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
239253 is_transposed = self .qdata .stride (0 ) < self .qdata .stride (1 )
240254 if is_transposed :
241255 M , K = self .shape [1 ], self .shape [0 ]
256+ scale_e4m3 = self ._scale_e4m3 .t ()
242257 else :
243258 M , K = self .shape [0 ], self .shape [1 ]
259+ scale_e4m3 = self ._scale_e4m3
244260
245261 if self ._is_swizzled_scales :
246- scale_e4m3 = from_blocked (self ._scale_e4m3 , M , K // self ._block_size )
247- else :
248- scale_e4m3 = self ._scale_e4m3
262+ scale_e4m3 = from_blocked (scale_e4m3 , M , K // self ._block_size )
249263
250264 return (
251265 scale_e4m3 .to (self ._orig_dtype )
@@ -369,6 +383,11 @@ def nvfp4_slice(func, types, args, kwargs):
369383
370384 M , K = x .shape [0 ], x .shape [1 ]
371385
386+ # The scale manipulations below assume a flattened scale. For now, we
387+ # flatten the scale, go through the calculations below, and then reshape
388+ # it back to the format which matches the shape of `qdata`.
389+ # TODO(future PR): update this
390+
372391 if x ._is_swizzled_scales :
373392 scale_rows = M
374393 scale_cols = K // x ._block_size
@@ -407,7 +426,9 @@ def nvfp4_slice(func, types, args, kwargs):
407426 else None
408427 )
409428
410- sliced_scale = aten .slice .Tensor (x ._scale_e4m3 , 0 , start_idx , end_idx , 1 )
429+ sliced_scale = aten .slice .Tensor (
430+ x ._scale_e4m3 .flatten (), 0 , start_idx , end_idx , 1
431+ )
411432 sliced_data = aten .slice .Tensor (x .qdata , 0 , start , end , step )
412433
413434 elif dim == 1 :
@@ -462,7 +483,7 @@ def nvfp4_slice(func, types, args, kwargs):
462483 row_start = row_block * elements_per_row_block
463484 col_start = row_start + start_col_block * elements_per_block
464485 col_end = row_start + end_col_block * elements_per_block
465- slices_to_extract .append (x ._scale_e4m3 [col_start :col_end ])
486+ slices_to_extract .append (x ._scale_e4m3 . flatten () [col_start :col_end ])
466487
467488 # Concatenate all the slices
468489 sliced_scale = torch .cat (slices_to_extract , dim = 0 )
@@ -515,6 +536,19 @@ def nvfp4_slice(func, types, args, kwargs):
515536
516537 sliced_scale = sliced_scale .flatten ()
517538
539+ # reshape at the end
540+ sliced_M = sliced_data .shape [0 ]
541+ # multiply by 2 to convert from bytes to num_elements
542+ sliced_K = sliced_data .shape [1 ] * 2
543+ if x ._is_swizzled_scales :
544+ scale_M , scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4 (sliced_M , sliced_K )
545+ else :
546+ # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
547+ # scale element
548+ scale_M = sliced_M
549+ scale_K = sliced_K // x ._block_size
550+ sliced_scale = sliced_scale .view (scale_M , scale_K )
551+
518552 # Create result tensor
519553 result = NVFP4Tensor (
520554 sliced_data ,
@@ -537,7 +571,7 @@ def nvfp4_t(func, types, args, kwargs):
537571 old = args [0 ]
538572 new = NVFP4Tensor (
539573 old .qdata .t (),
540- old ._scale_e4m3 ,
574+ old ._scale_e4m3 . t () ,
541575 old ._block_size ,
542576 old ._orig_dtype ,
543577 old ._per_tensor_scale ,
@@ -576,7 +610,9 @@ def _addmm_nvfp4_dispatch(
576610 The only difference is whether bias is None or not.
577611 """
578612 assert a .qdata .is_contiguous ()
613+ assert a ._scale_e4m3 .is_contiguous ()
579614 assert b .qdata .t ().is_contiguous ()
615+ assert b ._scale_e4m3 .t ().is_contiguous ()
580616 assert a ._block_size == 16 , f"NVFP4 requires block_size=16, got { a ._block_size } "
581617 assert b ._block_size == 16 , f"NVFP4 requires block_size=16, got { b ._block_size } "
582618
@@ -591,9 +627,9 @@ def _addmm_nvfp4_dispatch(
591627 a_scale_blocked = to_blocked (a_scale )
592628
593629 if b ._is_swizzled_scales :
594- b_scale_blocked = b ._scale_e4m3 # Already swizzled
630+ b_scale_blocked = b ._scale_e4m3 . t () # Already swizzled
595631 else :
596- b_scale = b ._scale_e4m3 .view (N , K // b ._block_size )
632+ b_scale = b ._scale_e4m3 .t (). view (N , K // b ._block_size )
597633 b_scale_blocked = to_blocked (b_scale )
598634
599635 # Merge double quant scales into 1 scale for Scale_In^D
0 commit comments