@@ -751,13 +751,37 @@ def nvfp4_quantize(
        AssertionError: If input dtype is not supported, tensor size is not
            divisible by block_size, tensor is not contiguous, or block_size != 16
    """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+class _Float8Round(torch.autograd.Function):
+    """
+    Cast a tensor to float8 and back to float32 with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return x.to(torch.float8_e4m3fn).to(torch.float32)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
    assert data_hp.dtype in (torch.bfloat16, torch.float), (
        f"{data_hp.dtype} not supported"
    )
    assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
    assert data_hp.is_contiguous(), "Only support contiguous data for now"
    assert block_size == 16, "NVFP4 requires block_size=16"

+    orig_dtype = data_hp.dtype
    orig_shape = data_hp.shape
    # Convert to float32 early for consistent precision with Triton implementation
    data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
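
For reference, the new _Float8Round autograd function is a straight-through estimator: the forward pass rounds to the nearest representable float8_e4m3fn value, while the backward pass returns the incoming gradient unchanged, so the rounded block scales stay differentiable. A minimal standalone check (not part of the diff), assuming a PyTorch build that supports torch.float8_e4m3fn on CPU, roughly 2.1 and later:

    import torch

    class _Float8Round(torch.autograd.Function):
        # Same definition as in the diff: round via float8_e4m3fn, identity backward.
        @staticmethod
        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
            return x.to(torch.float8_e4m3fn).to(torch.float32)

        @staticmethod
        def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
            return gy

    x = torch.tensor([0.017, 3.1, 448.0], requires_grad=True)
    y = _Float8Round.apply(x)                       # values snap to representable e4m3 points
    y.sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))  # gradient passes straight through
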
@@ -769,10 +793,8 @@ def nvfp4_quantize(
    out_scales = None
    if per_tensor_scale is None:
        # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
-            torch.float8_e4m3fn
-        )
-        block_scale_fp32 = block_scale_fp8.to(torch.float32)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
+        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
        data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
        out_scales = block_scale_fp8
    else:
@@ -784,8 +806,8 @@ def nvfp4_quantize(
        scaled_block_scales = block_scale_fp32 / per_tensor_scale
        scaled_block_scales_fp8 = torch.clamp(
            scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        ).to(torch.float8_e4m3fn)
-        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
+        )
+        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
        # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
        # To apply to data
        total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -794,8 +816,11 @@ def nvfp4_quantize(

    data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
    data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales.to(torch.float8_e4m3fn), data_lp
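
For the two-level-scaling branch, a sketch of the arithmetic that now produces the scale the data is divided by, with _Float8Round's forward written out inline. E4M3_EPS and F8E4M3_MAX are assumed here to be the usual e4m3 smallest-normal and maximum values, which may not exactly match the repo's constants, and total_block_scale is a hypothetical helper for illustration only:

    import torch

    E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny   # assumed: smallest normal, 2**-6
    F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # assumed: 448.0

    def total_block_scale(block_scale: torch.Tensor, per_tensor_scale: torch.Tensor) -> torch.Tensor:
        # Mirrors the per_tensor_scale branch: clamp the ratio, round it to e4m3
        # precision (what _Float8Round.apply does in its forward pass), then
        # re-apply the per-tensor scale to recover the value used to divide the data.
        scaled = torch.clamp(block_scale / per_tensor_scale, min=E4M3_EPS, max=F8E4M3_MAX)
        scaled_fp32 = scaled.to(torch.float8_e4m3fn).to(torch.float32)
        return per_tensor_scale * scaled_fp32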