 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import sys
 from dataclasses import dataclass
 from enum import Enum
@@ -112,7 +113,7 @@ def __new__(
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            qdata.stride(0) > qdata.stride(1),
+            qdata.stride(-2) > qdata.stride(-1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
@@ -174,21 +175,21 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
-        assert len(data_hp.shape) == 2, "unsupported"
-        M, K = data_hp.shape[0], data_hp.shape[1]
+        assert len(data_hp.shape) in (2, 3), "unsupported"
+        leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
 
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
-            assert data_hp.shape[1] % 16 == 0, (
-                f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
+            assert K % 16 == 0, (
+                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
             )
         if is_swizzled_scales:
-            scale_shape = (M, K // block_size)
+            scale_shape = (math.prod(leading_dims) * M, K // block_size)
             blockwise_scales = to_blocked(
                 blockwise_scales.view(scale_shape)
             ).flatten()
@@ -199,7 +200,7 @@ def to_nvfp4(
             # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
             # scale element
             scale_M, scale_K = M, K // block_size
-            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
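
A minimal sketch of the shape arithmetic behind the swizzled-scale change above: with one scale per 1x16 block along K, the leading (batch) dims are folded into the row count before `to_blocked`. The tensor values, shapes, and names here are illustrative assumptions, not part of the diff.

import math
import torch

data_hp = torch.randn(4, 128, 256, dtype=torch.bfloat16)  # hypothetical (B, M, K) input
block_size = 16  # NVFP4 uses 1x16 scale blocks

leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]

# Swizzled path: fold leading dims into the rows, one scale column per block of K.
scale_rows = math.prod(leading_dims) * M   # 4 * 128 = 512
scale_cols = K // block_size               # 256 // 16 = 16
blockwise_scales = torch.randn(scale_rows, scale_cols)  # stand-in scale values

# Non-swizzled path keeps the leading dims explicit in the view instead.
print(blockwise_scales.reshape(*leading_dims, M, scale_cols).shape)  # torch.Size([4, 128, 16])
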
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
         else:
-            M, K = self.shape[0], self.shape[1]
-        data = self.qdata.t() if is_transposed else self.qdata
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
+        data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
-        data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        data_f32 = data_f32.view(
+            *leading_dims, M, K // self._block_size, self._block_size
+        )
+        scale_e4m3_reshaped = self.get_hp_scales().view(
+            *leading_dims, M, K // self._block_size, 1
+        )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
 
         if is_transposed:
-            result = result.t()
+            result = result.transpose(-2, -1)
 
         return result
 
@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
-            scale_e4m3 = self._scale_e4m3.t()
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
+            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
         else:
-            M, K = self.shape[0], self.shape[1]
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
+            scale_e4m3 = from_blocked(
+                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+            )
 
         return (
             scale_e4m3.to(self._orig_dtype)
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert len(x.shape) == 2, (
+        f"only rank 2 is supported for slice, got rank {len(x.shape)}"
+    )
 
     M, K = x.shape[0], x.shape[1]
 
@@ -583,6 +593,26 @@ def nvfp4_t(func, types, args, kwargs):
     return new
 
 
+@implements([aten.transpose.int])
+def nvfp4_transpose(func, types, args, kwargs):
+    old, dim0, dim1 = args
+    assert dim0 == -2 and dim1 == -1, f"transpose unsupported for {dim0=} {dim1=}"
+    new_qdata = func(old.qdata, dim0, dim1, **kwargs)
+    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new = NVFP4Tensor(
+        new_qdata,
+        new_scale,
+        old._block_size,
+        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
+        old._is_swizzled_scales,
+        old.use_triton_kernel,
+        old.act_quant_kwargs,
+    )
+    return new
+
+
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
     data = args[0].qdata
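
The dequantize and scale paths above decide whether the data is stored transposed by comparing the strides of the last two dims, and the new transpose override only swaps those dims. A small self-contained sketch of that idea in plain PyTorch; the shapes and the helper name are illustrative assumptions, not code from the diff.

import torch

def last_two_dims_transposed(t: torch.Tensor) -> bool:
    # Same stride comparison as the diff: if the second-to-last dim has a
    # smaller stride than the last dim, the logical layout is a transpose of
    # a contiguous (..., M, K) buffer.
    return t.stride(-2) < t.stride(-1)

x = torch.randn(4, 128, 256)                           # hypothetical (B, M, K) tensor
print(last_two_dims_transposed(x))                      # False: contiguous layout
print(last_two_dims_transposed(x.transpose(-2, -1)))    # True: batched transpose
# Unlike Tensor.t(), which raises for tensors with more than 2 dims,
# transpose(-2, -1) works for rank-3 inputs too, which is why the diff
# replaces .t() with .transpose(-2, -1) in these code paths.
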