@@ -628,10 +628,12 @@ def cutlass_scaled_sparse_mm(
     return out
 
 
-def cutlass_fp4_gemm(a: torch.Tensor, b: torch.Tensor, input_sf: torch.Tensor,
-                     weight_sf: torch.dtype, global_sf: torch.dtype,
-                     workspace: torch.dtype, workspace_bytes: int,
-                     out_dtype: torch.dtype) -> torch.Tensor:
+# nvfp4
+def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
+                          block_scale_a: torch.Tensor,
+                          block_scale_b: torch.Tensor, gscale: torch.Tensor,
+                          workspace: torch.Tensor, workspace_bytes: int,
+                          out_dtype: torch.dtype) -> torch.Tensor:
     """
     Gemm when a and b have the nvfp4 datatype (currently represented as a byte),
     along with their respective block scales and a global scaling factor.
@@ -643,19 +645,80 @@ def cutlass_fp4_gemm(a: torch.Tensor, b: torch.Tensor, input_sf: torch.Tensor,
     n = b.shape[1]
     workspace_bytes = workspace.nbytes
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops._C.cutlass_fp4_gemm(out, a, b, input_sf, weight_sf, global_sf,
-                                  workspace, workspace_bytes)
+    torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b,
+                                       gscale, workspace, workspace_bytes)
     return out
 
 
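For orientation, here is a hedged call sketch of the wrapper above. Every shape, size, and block-scale layout in it is an assumption chosen only to satisfy the wrapper's own arithmetic (`n = b.shape[1]`, `workspace.nbytes`), not a layout this change mandates:

# Sketch only: all shapes, the workspace size, and the block-scale
# layouts below are illustrative assumptions.
m, k, n = 128, 64, 256
a = torch.zeros((m, k // 2), dtype=torch.uint8, device='cuda')  # packed fp4
b = torch.zeros((k // 2, n), dtype=torch.uint8, device='cuda')  # packed fp4
bs_a = torch.zeros((128, 1), dtype=torch.int32, device='cuda')  # assumed shape
bs_b = torch.zeros((256, 1), dtype=torch.int32, device='cuda')  # assumed shape
gscale = torch.tensor(1.0, device='cuda')
workspace = torch.empty(32 << 20, dtype=torch.uint8, device='cuda')
out = cutlass_scaled_fp4_mm(a, b, bs_a, bs_b, gscale,
                            workspace, workspace.nbytes, torch.bfloat16)
assert out.shape == (m, n) and out.dtype == torch.bfloat16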
-def quantize_to_fp4(input: torch.Tensor, input_sf: torch.Tensor,
-                    output_sf: torch.Tensor) -> torch.Tensor:
-    assert (input is torch.bfloat16 or input is torch.float16)
-    m = input.shape[0]
-    n = input.shape[1]
-    output = torch.empty((m, n // 2), dtype=torch.uint8, device=input.device)
-    torch.ops._C.quantize_fp4(output, input, input_sf, output_sf, False)
-    return output, output_sf
+def pad_up_fn(x, y):
+    """Pads x up to the nearest multiple of y."""
+    return ((x + y - 1) // y) * y
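A quick sanity check of the rounding arithmetic (inputs chosen purely for illustration):

assert pad_up_fn(130, 128) == 256  # rounds up to the next multiple of 128
assert pad_up_fn(4, 4) == 4        # already-aligned values are unchanged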
+
+
+def scaled_fp4_quant(
+        input: torch.Tensor,
+        global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes a BF16/FP16 input to NVFP4 precision. Returns the quantized
+    fp4 tensor (two values packed per uint8) and its corresponding block
+    scale, stored in a swizzled and possibly padded layout.
+    """
+    assert input.ndim >= 1, (
+        f'input.ndim needs to be >= 1, but got {input.ndim}.')
+    other_dims = 1 if input.ndim == 1 else -1
+    input = input.reshape(other_dims, input.shape[-1])
+    m, n = input.shape
+    block_size = 16
+    device = input.device
+
+    assert n % block_size == 0, (
+        f'last dim has to be multiple of 16, but got {n}.')
+    assert input.dtype in (torch.float16, torch.bfloat16), (
+        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')
+
+    # Two fp4 values are packed into each uint8.
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+    # The block scales are stored in a swizzled layout: the row dimension is
+    # padded up to a multiple of 128, and four float8_e4m3fn scaling factors
+    # are packed into each int32.
+    rounded_m = pad_up_fn(m, 128)
+    scale_n = n // block_size
+    rounded_n = pad_up_fn(scale_n, 4)
+    block_scale_out = torch.empty((rounded_m, rounded_n // 4),
+                                  device=device,
+                                  dtype=torch.int32)
+    torch.ops._C.scaled_fp4_quant(output, input, block_scale_out, global_scale,
+                                  False)
+    return output, block_scale_out
+
+
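A hedged usage sketch for the quantizer. The global-scale convention shown, the e4m3 maximum (448) times the e2m1 maximum (6) over the tensor's absolute maximum, is a common nvfp4 recipe and an assumption here, not something this diff prescribes:

x = torch.randn(128, 64, dtype=torch.float16, device='cuda')
global_scale = (448 * 6) / x.abs().float().max()  # assumed convention
x_fp4, x_block_scale = scaled_fp4_quant(x, global_scale)
assert x_fp4.shape == (128, 32)          # two e2m1 values per uint8
assert x_block_scale.shape == (128, 1)   # rows padded up to 128; four fp8
                                         # scales packed into each int32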
+def blockscale_interleave(input: torch.Tensor) -> torch.Tensor:
+    """
+    Takes the block scale `input` and returns an interleaved version of it.
+    The output `interleaved_block_scale` may be padded.
+    """
+    block_scale_shape = input.size()
+
+    # The block scale must be 2D (single expert) or 3D (batched by expert).
+    if len(block_scale_shape) not in (2, 3):
+        raise ValueError("Block scale should be a 2D or 3D tensor.")
+
+    # Extract dimensions based on whether the tensor is 2D or 3D.
+    num_experts = block_scale_shape[0] if len(block_scale_shape) == 3 else 1
+    rows = block_scale_shape[-2]
+    cols = block_scale_shape[-1]
+
+    expert_out_size = pad_up_fn(rows, 128) * pad_up_fn(cols, 4)
+    interleaved_block_scale = torch.zeros(expert_out_size * num_experts,
+                                          dtype=torch.int8,
+                                          device=input.device)
+
+    torch.ops._C.blockscale_interleave(interleaved_block_scale, input, rows,
+                                       cols, num_experts, expert_out_size)
+    return interleaved_block_scale
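And a hedged usage sketch for the interleaver, for a single 2D expert. The input dtype (float8_e4m3fn viewed as int8) and the row-major layout are assumptions, not requirements stated by this diff:

rows, cols = 256, 4  # e.g. block scales for a (256, 64) weight, 16-wide blocks
bs = torch.zeros((rows, cols), dtype=torch.float8_e4m3fn, device='cuda')
bs_interleaved = blockscale_interleave(bs.view(torch.int8))
# Flat, padded output: pad_up_fn(256, 128) * pad_up_fn(4, 4) == 256 * 4 == 1024
assert bs_interleaved.numel() == 1024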
 
 
 # aqlm