@@ -183,9 +183,17 @@ def to_nvfp4(
         if is_swizzled_scales:
             M, K = data_hp.shape[0], data_hp.shape[1]
             scale_shape = (M, K // block_size)
-            blockwise_scales = to_blocked(
-                blockwise_scales.view(scale_shape)
-            ).flatten()
+            blockwise_scales = blockwise_scales.view(scale_shape)
+            blockwise_scales = to_blocked(blockwise_scales)
+
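+            # to_blocked pads the scales out to 128-row x 4-column tiles; keep the
+            # result as a 2D view with those padded dimensions (rather than the old
+            # flattened layout) so it stays a matrix alongside data_hp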
+            scale_M = ceil_div(data_hp.shape[0], 128) * 128
+            scale_K = ceil_div(data_hp.shape[1] // 16, 4) * 4
+            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
@@ -220,6 +228,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
+        # TODO: debug the scale shape here
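+        # dequantize: group the unpacked values into (M, K // block_size, block_size)
+        # blocks and multiply each block by its e4m3 blockwise scale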
         data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
         scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
@@ -237,15 +246,17 @@ def get_hp_scales(self) -> torch.Tensor:
             torch.Tensor: Scales of the NVFP4Tensor
         """
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
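+            # a transposed NVFP4Tensor carries transposed scales (see nvfp4_t below),
+            # so transpose them back before un-swizzling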
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
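+            # from_blocked undoes the 128x4 swizzle and returns plain
+            # (M, K // block_size) scales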
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -366,6 +377,7 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert x._scale_e4m3.is_contiguous(), "Only support contiguous scale for now"
 
     M, K = x.shape[0], x.shape[1]
 
@@ -407,7 +419,7 @@ def nvfp4_slice(func, types, args, kwargs):
             else None
         )
 
-        sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1)
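+        # with the scales kept as a 2D (padded) matrix, dim-0 slicing can reuse the
+        # same row bounds as the data slice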
+        sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start, end, 1)
         sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step)
 
     elif dim == 1:
@@ -452,20 +464,24 @@ def nvfp4_slice(func, types, args, kwargs):
                 # Full width - no slicing needed
                 sliced_scale = x._scale_e4m3
             else:
-                # Extract specific column blocks from each row block
-                # Each row block in swizzled format contains n_col_blocks chunks of (32, 16)
-                elements_per_row_block = n_col_blocks * elements_per_block
-
-                # Build list of slices to extract
-                slices_to_extract = []
-                for row_block in range(n_row_blocks):
-                    row_start = row_block * elements_per_row_block
-                    col_start = row_start + start_col_block * elements_per_block
-                    col_end = row_start + end_col_block * elements_per_block
-                    slices_to_extract.append(x._scale_e4m3[col_start:col_end])
-
-                # Concatenate all the slices
-                sliced_scale = torch.cat(slices_to_extract, dim=0)
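+                # with the scales stored as a padded 2D swizzled matrix, slicing
+                # column blocks is expressed as a dim-1 slice of that matrix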
+                sliced_scale = aten.slice.Tensor(
+                    x._scale_e4m3, dim, start_scale_col, end_scale_col, step
+                ).contiguous()
 
         # Slice the data tensor
         packed_start = None if start is None else start // 2
@@ -537,7 +553,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
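+        # transpose the scales together with qdata so their layouts stay in sync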
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -577,6 +593,8 @@ def _addmm_nvfp4_dispatch(
577593 """
578594 assert a .qdata .is_contiguous ()
579595 assert b .qdata .t ().is_contiguous ()
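+    # mirror the qdata layout checks above: a's scales row-major, b's scales column-major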
+    assert a._scale_e4m3.is_contiguous()
+    assert b._scale_e4m3.t().is_contiguous()
     assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}"
     assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}"
 
@@ -615,7 +633,7 @@ def _addmm_nvfp4_dispatch(
         a.qdata.view(torch.float4_e2m1fn_x2),
         b.qdata.view(torch.float4_e2m1fn_x2),
         a_scale_blocked.view(torch.float8_e4m3fn),
-        b_scale_blocked.view(torch.float8_e4m3fn),
+        b_scale_blocked.t().view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
         out_dtype=a._orig_dtype,
         # scale_result=scale_result, # Not supported yet