diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 3c1c4b52f..ecc8aa10d 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1360,7 +1360,7 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
     return out.view(*act.shape[:-1], out_dim).to(act.dtype)
 
 
-def _linear_fp_act_fp8_weight_check(
+def _linear_fp8_act_fp8_weight_check(
     input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
@@ -1384,7 +1384,7 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
     return input_scale
 
 
-def _linear_fp_act_fp8_weight_impl(
+def _linear_fp8_act_fp8_weight_impl(
     input_tensor: AffineQuantizedTensor,
     weight_tensor: AffineQuantizedTensor,
     bias: Optional[torch.Tensor],
@@ -1473,7 +1473,7 @@ def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
-        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
+        (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
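
For context, the renamed check/impl pair is consumed by the dispatch table registered in `_register_aqt_quantized_linear_dispatches`: the table is a list of `(dispatch_condition, impl)` pairs that is walked in registration order, and the first condition that matches the `(input, weight, bias)` triple selects the kernel. Below is a minimal sketch of that pattern; the names `register_quantized_linear_dispatch`, `quantized_linear`, and the plain-linear fallback are illustrative assumptions, not torchao's actual API.

```python
from typing import Callable, List, Optional, Tuple

import torch

# Registry of (dispatch_condition, impl) pairs, tried in order; the first
# condition that matches the (input, weight, bias) triple wins.
_QUANTIZED_LINEAR_DISPATCH: List[Tuple[Callable, Callable]] = []


def register_quantized_linear_dispatch(condition: Callable, impl: Callable) -> None:
    """Append a (condition, impl) pair to the dispatch table (hypothetical helper)."""
    _QUANTIZED_LINEAR_DISPATCH.append((condition, impl))


def quantized_linear(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    for condition, impl in _QUANTIZED_LINEAR_DISPATCH:
        if condition(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    # No specialized kernel matched; fall back to a plain linear.
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
```

Because a condition and its implementation are always registered together as one tuple, the rename only has to stay consistent within that pair, as the third hunk does; dispatch behavior is unchanged.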