@@ -176,16 +176,31 @@ def to_nvfp4(
176176 f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
177177 )
178178 blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
179+
180+ # TODO(before land): share code for scale shape manipulation in the two
181+ # if branches
182+ scale_M = ceil_div(data_hp.shape[0], 128) * 32
183+ scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
184+ blockwise_scales = blockwise_scales.view(scale_M, scale_K)
179185 else:
180186 blockwise_scales, data_lp = nvfp4_quantize(
181187 data_hp, block_size, per_tensor_scale
182188 )
183189 if is_swizzled_scales:
184190 M, K = data_hp.shape[0], data_hp.shape[1]
185191 scale_shape = (M, K // block_size)
186- blockwise_scales = to_blocked(
187- blockwise_scales.view(scale_shape)
188- ).flatten()
192+ # print(1, blockwise_scales.shape)
193+ blockwise_scales = blockwise_scales.view(scale_shape)
194+ # print(2, blockwise_scales.shape, blockwise_scales)
195+ blockwise_scales = to_blocked(blockwise_scales)
196+ print(3, blockwise_scales.shape, blockwise_scales)
197+
198+ # match shape of data_hp
199+ # in swizzled nvfp4, a 128x128 data unpacked / 128x64 data packed maps to a 32x16 scale tile
200+ scale_M = ceil_div(data_hp.shape[0], 128) * 32
201+ scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 16
202+ blockwise_scales = blockwise_scales.view(scale_M, scale_K)
203+ print(4, blockwise_scales.shape, blockwise_scales)
189204
190205 return NVFP4Tensor(
191206 data_lp,
@@ -212,6 +227,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
212227 torch.Tensor: Dequantized tensor in the target dtype
213228 """
214229 is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
230+ print('is_transposed', is_transposed)
215231 if is_transposed:
216232 M, K = self.shape[1], self.shape[0]
217233 else:
@@ -220,8 +236,12 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
220236 data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
221237 data_f32 = f4_unpacked_to_f32(data_unpacked)
222238
239+ # next: debug scale shape here
223240 data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
224- scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
241+ scales = self.get_hp_scales()
242+ scales_tmp = scales.reshape(32, -1)
243+ print('scales', scales_tmp.shape, scales_tmp[0:8])
244+ scale_e4m3_reshaped = scales.view(M, K // self._block_size, 1)
225245 data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
226246 result = data_scaled.view(M, K).to(target_dtype)
227247
@@ -237,15 +257,17 @@ def get_hp_scales(self) -> torch.Tensor:
237257 torch.Tensor: Scales of the NVFP4Tensor
238258 """
239259 is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
260+ print("is_transposed", is_transposed)
240261 if is_transposed:
241262 M, K = self.shape[1], self.shape[0]
263+ scale_e4m3 = self._scale_e4m3.t()
242264 else:
243265 M, K = self.shape[0], self.shape[1]
266+ scale_e4m3 = self._scale_e4m3
244267
245268 if self._is_swizzled_scales:
246- scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
247- else:
248- scale_e4m3 = self._scale_e4m3
269+ # import pdb; pdb.set_trace()
270+ scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
249271
250272 return (
251273 scale_e4m3.to(self._orig_dtype)
@@ -366,6 +388,7 @@ def nvfp4_slice(func, types, args, kwargs):
366388 raise ValueError("Only support aten.slice with step=1")
367389
368390 assert x.qdata.is_contiguous(), "Only support contiguous data for now"
391+ assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"
369392
370393 M, K = x.shape[0], x.shape[1]
371394
@@ -376,6 +399,22 @@ def nvfp4_slice(func, types, args, kwargs):
376399 n_col_blocks = ceil_div(scale_cols, 4)
377400 elements_per_block = 32 * 16  # 512 elements
378401
402+ #
403+ # See https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
404+ # for qdata vs scale layout. Here is a summary specific to nvfp4:
405+ #
406+ # 1. qdata tile shape is (128, 32) packed fp4, which is (128, 64) unpacked fp4
407+ # 2. scale tile shape is (32, 16)
408+ # 3. correspondence of qdata vs scale tiles is as follows, in a 2 by 2 tile example
409+ #
410+ # | tile_idx | qdata_rows | qdata_cols | scale_rows | scale_cols |
411+ # ----------------------------------------------------------------
412+ # |    0     |   0:127    |    0:31    |    0:31    |    0:15    |
413+ # |    1     |  128:255   |    0:31    |   32:63    |    0:15    |
414+ # |    2     |   0:127    |   32:63    |    0:31    |   16:31    |
415+ # |    3     |  128:255   |   32:63    |   32:63    |   16:31    |
416+ #
417+
379418 if dim == 0:
380419 # Row slicing
381420 # Handle sys.maxsize (default slice end)
@@ -407,7 +446,9 @@ def nvfp4_slice(func, types, args, kwargs):
407446 else None
408447 )
409448
410- sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
449+ # TODO(before land): this is wrong, it works but need to express in terms of
450+ # properly laid out scale as in the comment block above
451+ sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
411452 sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
412453
413454 elif dim == 1:
@@ -452,20 +493,43 @@ def nvfp4_slice(func, types, args, kwargs):
452493 # Full width - no slicing needed
453494 sliced_scale = x._scale_e4m3
454495 else:
455- # Extract specific column blocks from each row block
456- # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
457- elements_per_row_block = n_col_blocks * elements_per_block
458-
459- # Build list of slices to extract
460- slices_to_extract = []
461- for row_block in range(n_row_blocks):
462- row_start = row_block * elements_per_row_block
463- col_start = row_start + start_col_block * elements_per_block
464- col_end = row_start + end_col_block * elements_per_block
465- slices_to_extract.append(x._scale_e4m3[col_start:col_end])
466-
467- # Concatenate all the slices
468- sliced_scale = torch.cat(slices_to_extract, dim=0)
496+
497+ scale = x._scale_e4m3
498+
499+ # reshape to expected shape
500+ # TODO(before land): do this when swizzling so we don't have to do it here
501+
502+ # TODO(before land): comment the mul by 2 here
503+ scale_rows = n_row_blocks * 16 * 2
504+ scale_cols = n_col_blocks * 16
505+ scale = scale.view(scale_rows, scale_cols)
506+
507+ # convert from hp_tensor row to scale row
508+ start_scale_col = 0 if start is None else (start // 128 * 16)
509+ end_scale_col = scale_cols if end is None or end >= K else (end // 16 * 4)
510+ # import pdb; pdb.set_trace()
511+
512+ if False:
513+ # Extract specific column blocks from each row block
514+ # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
515+ elements_per_row_block = n_col_blocks * elements_per_block
516+
517+ # Build list of slices to extract
518+ slices_to_extract = []
519+ for row_block in range(n_row_blocks):
520+ row_start = row_block * elements_per_row_block
521+ col_start = row_start + start_col_block * elements_per_block
522+ col_end = row_start + end_col_block * elements_per_block
523+ slices_to_extract.append(x._scale_e4m3[col_start:col_end])
524+
525+ # Concatenate all the slices
526+ sliced_scale = torch.cat(slices_to_extract, dim=0)
527+ # import pdb; pdb.set_trace()
528+ sliced_scale = aten.slice.Tensor(
529+ # x._scale_e4m3, dim, start_scale_col, end_scale_col, step
530+ scale, dim, start_scale_col, end_scale_col, step
531+ ).contiguous()
532+ # import pdb; pdb.set_trace()
469533
470534 # Slice the data tensor
471535 packed_start = None if start is None else start // 2
@@ -537,7 +601,7 @@ def nvfp4_t(func, types, args, kwargs):
537601 old = args[0]
538602 new = NVFP4Tensor(
539603 old.qdata.t(),
540- old._scale_e4m3,
604+ old._scale_e4m3.t(),
541605 old._block_size,
542606 old._orig_dtype,
543607 old._per_tensor_scale,
@@ -577,6 +641,8 @@ def _addmm_nvfp4_dispatch(
577641 """
578642 assert a.qdata.is_contiguous()
579643 assert b.qdata.t().is_contiguous()
644+ assert a._scale_e4m3.is_contiguous()
645+ assert b._scale_e4m3.t().is_contiguous()
580646 assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
581647 assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
582648
@@ -615,7 +681,7 @@ def _addmm_nvfp4_dispatch(
615681 a.qdata.view(torch.float4_e2m1fn_x2),
616682 b.qdata.view(torch.float4_e2m1fn_x2),
617683 a_scale_blocked.view(torch.float8_e4m3fn),
618- b_scale_blocked.view(torch.float8_e4m3fn),
684+ b_scale_blocked.t().view(torch.float8_e4m3fn),
619685 bias=None if should_add_bias_separately else bias,
620686 out_dtype=a._orig_dtype,
621687 # scale_result=scale_result,  # Not supported yet
0 commit comments