diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 9b2bd087d8e..94ab6de0e29 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -144,7 +144,6 @@ executorch_generated_lib( visibility = ["PUBLIC"], deps = [ "//executorch/backends/cadence/generic/kernels:cadence_kernels", - # Individual operator targets instead of combined cadence_generic_ops "//executorch/backends/cadence/generic/operators:op_requantize_out", "//executorch/backends/cadence/generic/operators:im2row_out", "//executorch/backends/cadence/generic/operators:dequantize_per_tensor", diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 95c35055e9c..2e9e187168f 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -184,12 +184,60 @@ - arg_meta: null kernel_name: impl::generic::quantize_per_tensor_out +- func: cadence::quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym8s_out + +- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym8u_out + +- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym16s_out + +- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::quantize_per_tensor_asym16u_out + - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: impl::generic::dequantize_per_tensor_out +- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym8s_out + +- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym8u_out + +- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym16s_out + +- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::generic::dequantize_per_tensor_asym16u_out + - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index a0e84d94300..c48aac8686a 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -284,12 +284,61 @@ - arg_meta: null kernel_name: impl::HiFi::quantize_per_tensor_out +- func: cadence::quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym8s_out + +- func: cadence::quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym8u_out + +- func: cadence::quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym16s_out + +- func: cadence::quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::quantize_per_tensor_asym16u_out + + - func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: impl::HiFi::dequantize_per_tensor_out +- func: cadence::dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym8s_out + +- func: cadence::dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym8u_out + +- func: cadence::dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym16s_out + +- func: cadence::dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: impl::HiFi::dequantize_per_tensor_asym16u_out + - func: cadence::quantized_conv2d_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index e483bea79d1..567d86af457 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -28,12 +28,64 @@ "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) 
out) -> Tensor(a!)" ) +lib.define( + "quantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + +lib.define( + "quantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "quantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) + lib.define( "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" ) lib.define( "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "dequantize_per_tensor_asym8s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym8s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) 
out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym8u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym8u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym16s(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym16s.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "dequantize_per_tensor_asym16u(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)" +) +lib.define( + "dequantize_per_tensor_asym16u.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" @@ -541,6 +593,54 @@ def quantize_per_tensor_meta( return input.new_empty(input.size(), dtype=dtype) +@register_fake("cadence::quantize_per_tensor_asym8s") +def quantize_per_tensor_asym8s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym8u") +def quantize_per_tensor_asym8u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym16s") +def quantize_per_tensor_asym16s_meta( + input: 
torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + +@register_fake("cadence::quantize_per_tensor_asym16u") +def quantize_per_tensor_asym16u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=dtype) + + @register_fake("cadence::dequantize_per_tensor") def dequantize_per_tensor_meta( input: torch.Tensor, @@ -553,6 +653,54 @@ def dequantize_per_tensor_meta( return input.new_empty(input.size(), dtype=torch.float) +@register_fake("cadence::dequantize_per_tensor_asym8s") +def dequantize_per_tensor_asym8s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym8u") +def dequantize_per_tensor_asym8u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym16s") +def dequantize_per_tensor_asym16s_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + +@register_fake("cadence::dequantize_per_tensor_asym16u") +def dequantize_per_tensor_asym16u_meta( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=torch.float) + + @register_fake("cadence::quantized_add") def quantized_add_meta( X: torch.Tensor, diff --git a/backends/cadence/aot/type_dispatch.py 
b/backends/cadence/aot/type_dispatch.py index 3bf86ad2e50..97a25938e8d 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -27,6 +27,7 @@ class OpConfig: base_name: str type_dispatch_suffixes: dict[tuple[torch.dtype, ...], str] weight_arg_idx: Optional[int] = None + is_quant_op: bool = False variant: str = "per_tensor" @@ -100,6 +101,27 @@ class CompileTimeTypeDispatchPass(ExportPass): }, variant="default", ), + exir_ops.edge.cadence.quantize_per_tensor.default: OpConfig( + "quantize_per_tensor", + type_dispatch_suffixes={ + (torch.int8,): "asym8s", + (torch.uint8,): "asym8u", + (torch.int16,): "asym16s", + (torch.uint16,): "asym16u", + }, + variant="default", + is_quant_op=True, + ), + exir_ops.edge.cadence.dequantize_per_tensor.default: OpConfig( + "dequantize_per_tensor", + type_dispatch_suffixes={ + (torch.int8,): "asym8s", + (torch.uint8,): "asym8u", + (torch.int16,): "asym16s", + (torch.uint16,): "asym16u", + }, + variant="default", + ), } def call_operator( @@ -120,6 +142,8 @@ def call_operator( if config.weight_arg_idx is not None: weight_dtype = args[config.weight_arg_idx].to_tensor().dtype dtype_key = (input_dtype, weight_dtype) + elif config.is_quant_op: + dtype_key = (args[5],) else: dtype_key = (input_dtype,) diff --git a/backends/cadence/generic/operators/dequantize_per_tensor.cpp b/backends/cadence/generic/operators/dequantize_per_tensor.cpp index 1481981ee0b..aedc6e10309 100644 --- a/backends/cadence/generic/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/generic/operators/dequantize_per_tensor.cpp @@ -18,7 +18,7 @@ using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; using ::impl::generic::kernels::dequantize; -void dequantize_per_tensor_out( +Tensor& dequantize_per_tensor_out( KernelRuntimeContext& context, const Tensor& input, double scale, @@ -50,6 +50,71 @@ void dequantize_per_tensor_out( "Unhandled input dtype %hhd", static_cast(input.scalar_type())); } + 
return out; +} + +Tensor& dequantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; +} + +Tensor& dequantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); + return out; } } // namespace native diff --git a/backends/cadence/generic/operators/quantize_per_tensor.cpp b/backends/cadence/generic/operators/quantize_per_tensor.cpp index 29b233dab09..f2a413be35d 100644 --- a/backends/cadence/generic/operators/quantize_per_tensor.cpp +++ 
b/backends/cadence/generic/operators/quantize_per_tensor.cpp @@ -20,7 +20,7 @@ using ::impl::generic::kernels::quantize; // Quantize the input tensor (PT2 version). Note that quant_ are not // used in any computation. -void quantize_per_tensor_out( +Tensor& quantize_per_tensor_out( KernelRuntimeContext& context, const Tensor& input, double scale, @@ -34,30 +34,91 @@ void quantize_per_tensor_out( if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - impl::generic::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); - impl::generic::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if ( out.scalar_type() == ScalarType::Bits16 || out.scalar_type() == ScalarType::UInt16) { uint16_t* out_data = out.mutable_data_ptr(); - impl::generic::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Short) { int16_t* out_data = out.mutable_data_ptr(); - impl::generic::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. 
/ scale, zero_point, numel); } else { ET_CHECK_MSG( false, "Unhandled input dtype %hhd", static_cast(out.scalar_type())); } + return out; } -} // namespace native -} // namespace generic -} // namespace impl +Tensor& quantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); + return out; +} + +Tensor& quantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. 
/ scale, zero_point, numel); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index 193b43c2b6d..fa0f128b229 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -44,6 +44,7 @@ def define_common_targets(): ], visibility = [ "//executorch/backends/cadence/...", + "@EXECUTORCH_CLIENTS", ], ) diff --git a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp index f416082b10f..317e7ed8ef9 100644 --- a/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp @@ -53,6 +53,51 @@ void dequantize_per_tensor_out( } } +void dequantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint8_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +void dequantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const int16_t* input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + +void dequantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + size_t numel = out.numel(); + const uint16_t* 
input_data = input.const_data_ptr(); + dequantize(out_data, input_data, scale, zero_point, numel); +} + } // namespace native } // namespace HiFi } // namespace impl diff --git a/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp new file mode 100644 index 00000000000..d1099b1a4db --- /dev/null +++ b/backends/cadence/hifi/operators/op_dequantize_per_tensor_asym8s.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void dequantize_per_tensor_asym8s_out( + KernelRuntimeContext& ctx, + const Tensor& input, + double scale, + int64_t zero_point, + __ET_UNUSED int64_t quant_min, + __ET_UNUSED int64_t quant_max, + ScalarType dtype, + Tensor& out) { + float* out_data = out.mutable_data_ptr(); + const size_t numel = out.numel(); + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8s_f32( + out_data, input_data, zero_point, scale, numel); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp index b2f47619f05..9bc3d48699e 100644 --- a/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp @@ -19,10 +19,13 @@ namespace impl { namespace HiFi { namespace native { + namespace { + using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::KernelRuntimeContext; +using 
::impl::HiFi::kernels::quantize; // Add checks for dtype quant min/max bounds. template @@ -92,22 +95,19 @@ void quantize_per_tensor_out( const size_t numel = out.numel(); if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); xa_nn_elm_quantize_f32_asym8s( out_data, input_data, scale, zero_point, numel); } else if (out.scalar_type() == ScalarType::Short) { int16_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else if ( out.scalar_type() == ScalarType::Bits16 || out.scalar_type() == ScalarType::UInt16) { uint16_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + quantize(out_data, input_data, 1. / scale, zero_point, numel); } else { ET_KERNEL_CHECK_MSG( ctx, @@ -119,6 +119,51 @@ void quantize_per_tensor_out( } } -} // namespace native -} // namespace HiFi -} // namespace impl +void quantize_per_tensor_asym8u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint8_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. 
/ scale, zero_point, numel); +} + +void quantize_per_tensor_asym16s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); +} + +void quantize_per_tensor_asym16u_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + uint16_t* out_data = out.mutable_data_ptr(); + quantize(out_data, input_data, 1. / scale, zero_point, numel); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp new file mode 100644 index 00000000000..552b6acf150 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantize_per_tensor_asym8s.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantize_per_tensor_asym8s_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + const float* input_data = input.const_data_ptr(); + size_t numel = out.numel(); + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8s(out_data, input_data, scale, zero_point, numel); +} + +} // namespace native +} // namespace HiFi +} // namespace impl diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index ca474e8183b..1f9814c4a4e 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -44,6 +44,7 @@ OPERATORS = [ "cat", "clamp", "dequantize_per_tensor", + "dequantize_per_tensor_asym8s", "div", "embedding", "eq", @@ -95,6 +96,7 @@ OPERATORS = [ "quantized_relu_asym8s_asym8s_per_tensor_out", "quantized_relu_asym8u_asym8u_per_tensor_out", "quantize_per_tensor", + "quantize_per_tensor_asym8s", "remainder", "rsqrt", "select_copy",