From 0bdde92114b470823aa24725bf3b0811e980c8ce Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Fri, 20 Sep 2024 13:17:23 -0700
Subject: [PATCH] Rename Floating point to fp8 (#909)

---
 torchao/dtypes/affine_quantized_tensor.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 3c1c4b52f..ecc8aa10d 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1360,7 +1360,7 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-def _linear_fp_act_fp8_weight_check(
+def _linear_fp8_act_fp8_weight_check(
     input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
@@ -1384,7 +1384,7 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
     return input_scale
 
 
-def _linear_fp_act_fp8_weight_impl(
+def _linear_fp8_act_fp8_weight_impl(
     input_tensor: AffineQuantizedTensor,
     weight_tensor: AffineQuantizedTensor,
     bias: Optional[torch.Tensor],
@@ -1473,7 +1473,7 @@ def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
-        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
+        (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
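
For context on the table being edited in the last hunk: _register_aqt_quantized_linear_dispatches pairs each check predicate with an implementation, and the first pair whose check accepts the (activation, weight, bias) combination handles the linear op. Below is a minimal, self-contained sketch of that (check, impl) dispatch pattern under those assumptions; the registry and helper names (register_linear_dispatch, dispatch_linear, _fp8_act_fp8_weight_check, _fp8_act_fp8_weight_impl) are hypothetical stand-ins, not torchao's actual API.

    # Illustrative sketch of a (check, impl) linear dispatch table.
    # All names here are hypothetical; this is not torchao's implementation.
    from typing import Callable, List, Optional, Tuple

    import torch
    import torch.nn.functional as F

    # Each entry pairs a predicate over (input, weight, bias) with the
    # implementation to run when that predicate matches.
    _LINEAR_DISPATCH_TABLE: List[Tuple[Callable, Callable]] = []

    def register_linear_dispatch(check: Callable, impl: Callable) -> None:
        # Append a (check, impl) pair; earlier registrations take priority.
        _LINEAR_DISPATCH_TABLE.append((check, impl))

    def dispatch_linear(
        input_tensor: torch.Tensor,
        weight_tensor: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Walk the table in registration order and run the first impl
        # whose check accepts this (activation, weight, bias) combination.
        for check, impl in _LINEAR_DISPATCH_TABLE:
            if check(input_tensor, weight_tensor, bias):
                return impl(input_tensor, weight_tensor, bias)
        # No specialized kernel matched; fall back to the plain linear op.
        return F.linear(input_tensor, weight_tensor, bias)

    # Example registration, analogous to the renamed fp8 pair in the patch:
    def _fp8_act_fp8_weight_check(inp, w, bias) -> bool:
        # Match only when both activation and weight are fp8 (e4m3).
        return (inp.dtype == torch.float8_e4m3fn
                and w.dtype == torch.float8_e4m3fn)

    def _fp8_act_fp8_weight_impl(inp, w, bias) -> torch.Tensor:
        # Placeholder body: a real kernel would use a scaled fp8 matmul.
        return F.linear(inp.to(torch.float32), w.to(torch.float32), bias)

    register_linear_dispatch(_fp8_act_fp8_weight_check,
                             _fp8_act_fp8_weight_impl)

    if __name__ == "__main__":
        x = torch.randn(2, 4).to(torch.float8_e4m3fn)
        w = torch.randn(3, 4).to(torch.float8_e4m3fn)
        print(dispatch_linear(x, w, None).shape)  # torch.Size([2, 3])

Registration order doubles as priority in this sketch, which is why the rename in the patch only touches the pair's names and not its position: moving an entry in the list would change which kernel wins when several checks overlap.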