diff --git a/src/common/convolution_pd.hpp b/src/common/convolution_pd.hpp index ee6f631b405..80b619d8d58 100644 --- a/src/common/convolution_pd.hpp +++ b/src/common/convolution_pd.hpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright 2016-2024 Intel Corporation +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -233,8 +234,16 @@ struct convolution_pd_t : public primitive_desc_t { || invariant_dst_md()->data_type == dst_dt) && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); - if (with_bias() && bia_dt != data_type::undef) + if (with_bias() && bia_dt != data_type::undef) { +#ifdef __aarch64__ + // ACL only supports s32 bias for quantization, so internally + // we convert from f32 to s32; hence the types don't match here. + if (utils::one_of( + dst_dt, data_type_t::dnnl_s8, data_type_t::dnnl_u8)) + return ok; +#endif ok = ok && invariant_bia_md()->data_type == bia_dt; + } return ok; } diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 1eec393ad2f..83032db0689 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -179,6 +179,7 @@ enum { key_conv_amx_wsp_buffer, key_conv_bia_reduction, key_conv_bias_bf16_convert_wsp, + key_conv_bias_s32_convert, key_conv_cudnn, key_conv_cudnn_algo, key_conv_cudnn_filter, diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp index 15437746069..16a86856d82 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp @@ -65,8 +65,13 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, everyone_is(data_type::f16, src_d.data_type(), wei_d.data_type(), dst_d.data_type()), everyone_is(data_type::bf16, src_d.data_type(), - wei_d.data_type(), dst_d.data_type())), - " src, dst and wei must be fp16, bf16 or fp32"); + wei_d.data_type(), dst_d.data_type()), + everyone_is(data_type::s8, src_d.data_type(), + wei_d.data_type(), dst_d.data_type()), + (everyone_is(data_type::u8, src_d.data_type(), + dst_d.data_type()) + && wei_d.data_type() == data_type::s8)), + " src, dst and wei must be s8, u8, bf16, fp16 or fp32"); // batch size const int mb = src_d.dims()[0]; @@ -165,7 +170,8 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, : arm_compute::DataLayout::NCHW; // all have the same datatype - auto acl_data_type = acl_utils::get_acl_data_t(src_d.data_type()); + auto acl_data_type + = acl_utils::get_acl_data_t(src_d.data_type(), acp.is_quantized); // clang-format off acp.src_tensor_info = arm_compute::TensorInfo( @@ -179,8 +185,9 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, is_nhwc ? arm_compute::TensorShape(ic, kw, kh, oc) : arm_compute::TensorShape(kw, kh, ic, oc), 1, - acl_data_type, + acl_utils::get_acl_data_t(wei_d.data_type(), acp.is_quantized), acl_layout); + if(is_depthwise) { // We need to set that values are not constant so that we // we can update them in-place in ACL @@ -198,10 +205,20 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, acp.with_bias ? arm_compute::TensorShape(oc) : arm_compute::TensorShape(), 1, - acl_data_type, + acp.is_quantized ?
acl_utils::get_acl_data_t(data_type::s32) : acl_data_type, acl_layout); // clang-format on + if (acp.is_quantized) { + // ACL rejects the operation if quantization information is empty during configuration. + // Since the correct parameters are not available at this stage, we provide placeholder values. + // These values are then updated with the correct ones during the run stage. + arm_compute::QuantizationInfo qi {1.0, 0, true}; + acp.src_tensor_info.set_quantization_info(qi); + acp.wei_tensor_info.set_quantization_info(qi); + acp.dst_tensor_info.set_quantization_info(qi); + } + // ACL Winograd is not prepared for fixed format kernels if (acp.alg_winograd) { const bool is_1d = ndims == 3; @@ -216,7 +233,7 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, // Are we allowed to cast down to bf16 or not? acp.fast_math = one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any); - if (is_depthwise) { + if (is_depthwise || acp.is_quantized) { // There is no support for fixed format kernels for depthwise convolution // in ACL so we are going to use weight format that we set up earlier return status::success; diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp index 37a3d6c3d98..fffaa7e4a73 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp @@ -20,9 +20,11 @@ #include #include "acl_post_ops.hpp" #include "acl_utils.hpp" -#include "arm_compute/runtime/experimental/operators/CpuDepthwiseConv2d.h" #include "cpu/cpu_convolution_pd.hpp" -#include +#include "cpu/cpu_primitive.hpp" + +#include "arm_compute/runtime/experimental/operators/CpuGemmConv2d.h" + namespace dnnl { namespace impl { namespace cpu { @@ -44,6 +46,8 @@ struct acl_conv_conf_t { // algorithm can be set to algorithm::convolution_auto and later on we need to // skip fixed-format protocol as ACL Winograd does not support it. bool alg_winograd; + // currently, only CpuGemmConv2d has the static quantization update interface. + bool is_quantized; arm_compute::TensorInfo src_tensor_info; arm_compute::TensorInfo wei_tensor_info; arm_compute::TensorInfo bia_tensor_info; @@ -70,11 +74,13 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); template -status_t init_scratchpad(op_t &conv, memory_tracking::registrar_t &scratchpad, +status_t init_scratchpad(const op_t &conv, + memory_tracking::registrar_t &scratchpad, const std::map &conv_keys, engine_t *engine, post_ops_t &post_ops, dnnl::impl::post_ops_t &attr_post_ops, arm_compute::ActivationLayerInfo &act_info, bool &use_dst_acc_for_sum, - const dnnl::impl::memory_desc_t &dst_md) { + const dnnl::impl::memory_desc_t &dst_md, + const dnnl::impl::memory_desc_t &bias_md, const bool is_quantized) { // Book temp mem. 
const auto aux_mem_req = conv.workspace(); @@ -95,6 +101,12 @@ status_t init_scratchpad(op_t &conv, memory_tracking::registrar_t &scratchpad, dst_d.data_type_size()); } + if (is_quantized && bias_md.format_kind != format_kind::undef) { + const memory_desc_wrapper bias_d(&bias_md); + scratchpad.book(memory_tracking::names::key_conv_bias_s32_convert, + bias_d.nelems(), bias_d.data_type_size()); + } + return status::success; } @@ -102,7 +114,7 @@ template status_t execute_forward_conv_acl(const exec_ctx_t &ctx, - conv_obj_t *acl_conv_obj, const conv_pd_t *pd, + conv_obj_t *acl_conv_obj, const conv_pd_t *pd_, const std::map &conv_keys) { auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); @@ -115,16 +127,49 @@ status_t execute_forward_conv_acl(const exec_ctx_t &ctx, arm_compute::Tensor bia_tensor = nullptr; arm_compute::Tensor dst_tensor; - auto const acp = pd->acp_; + auto const acp = pd_->acp_; src_tensor.allocator()->init(acp.src_tensor_info); wei_tensor.allocator()->init(acp.wei_tensor_info); dst_tensor.allocator()->init(acp.dst_tensor_info); + const auto scratchpad = ctx.get_scratchpad_grantor(); + + if (acp.is_quantized) { + // DEFINE_(ARG|ZERO)... demands 'pd' as a function + auto pd = [pd_] { return pd_; }; + + DEFINE_ARG_SCALES_BUFFER(src_scale, DNNL_ARG_SRC); + DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); + DEFINE_ARG_SCALES_BUFFER(wei_scale, DNNL_ARG_WEIGHTS); + DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS); + DEFINE_ARG_SCALES_BUFFER(dst_scale, DNNL_ARG_DST); + DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); + + // s8s8s8 uses D = Sx*Sy*(XY + X*zy + Y*zx + zx*zy) and u8s8u8 uses D = Sx*Sy*(XW - X*zw - W*zx + zx*zw) + if (dst_tensor.info()->data_type() == arm_compute::DataType::QASYMM8) { + src_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo( + *src_scale, -src_zero_point, true)); + wei_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo( + *wei_scale, -wei_zero_point, true)); + } else { + src_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo( + *src_scale, src_zero_point, true)); + wei_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo( + *wei_scale, wei_zero_point, true)); + } + + // for efficiency reasons, OneDNN saves the inverse of the destination + dst_tensor.info()->set_quantization_info(arm_compute::QuantizationInfo( + 1.0 / (*dst_scale), dst_zero_point, true)); + } + src_tensor.allocator()->import_memory(const_cast(src_base)); wei_tensor.allocator()->import_memory(const_cast(wei_base)); - const auto scratchpad = ctx.get_scratchpad_grantor(); - // If we have an unfused sum post op, put the result in a scratchpad tensor. 
// Result will be summed to the dst during acl_post_ops.execute auto dst_base = acp.use_dst_acc_for_sum @@ -133,10 +178,30 @@ status_t execute_forward_conv_acl(const exec_ctx_t &ctx, dst_tensor.allocator()->import_memory(dst_base); if (acp.with_bias) { - auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS); - bia_tensor.allocator()->init(acp.bia_tensor_info); - bia_tensor.allocator()->import_memory( - const_cast(bia_base)); + if (acp.is_quantized) { + auto bia_s32_base = scratchpad.get( + memory_tracking::names::key_conv_bias_s32_convert); + auto bia_f32_base = CTX_IN_MEM(const float32_t *, DNNL_ARG_BIAS); + auto src_scale + = src_tensor.info()->quantization_info().uniform().scale; + auto wei_scale + = wei_tensor.info()->quantization_info().uniform().scale; + const float bias_scale = 1 / (src_scale * wei_scale); + const int num_elements + = acp.bia_tensor_info.total_size() / sizeof(float32_t); + parallel_nd(num_elements, [&](dim_t e) { + const auto b + = int32_t(std::round(bia_f32_base[e] * bias_scale)); + bia_s32_base[e] = b; + }); + bia_tensor.allocator()->init(acp.bia_tensor_info); + bia_tensor.allocator()->import_memory(bia_s32_base); + } else { + auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS); + bia_tensor.allocator()->init(acp.bia_tensor_info); + bia_tensor.allocator()->import_memory( + const_cast(bia_base)); + } } // Constness of the weight tensor matters for depthwise conv in ACL. @@ -167,10 +232,17 @@ status_t execute_forward_conv_acl(const exec_ctx_t &ctx, } } + if (acp.is_quantized) { + arm_compute::experimental::op::CpuGemmConv2d *conv + = dynamic_cast( + &acl_conv_obj->conv); + if (conv) conv->update_quantization_parameters(pack); + } + acl_conv_obj->conv.run(pack); void *dst = dst_tensor.buffer(); - pd->post_ops.execute(ctx, dst); + pd_->post_ops.execute(ctx, dst); return status::success; } diff --git a/src/cpu/aarch64/acl_depthwise_convolution.cpp b/src/cpu/aarch64/acl_depthwise_convolution.cpp index 4752cfd5852..d18c51c5a0d 100644 --- a/src/cpu/aarch64/acl_depthwise_convolution.cpp +++ b/src/cpu/aarch64/acl_depthwise_convolution.cpp @@ -74,7 +74,7 @@ status_t acl_depthwise_convolution_fwd_t::pd_t::init(engine_t *engine) { auto scratchpad = scratchpad_registry().registrar(); return init_scratchpad(conv, scratchpad, depthwise_conv_keys, engine, post_ops, attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, - dst_md_); + dst_md_, bias_md_, false); } status_t acl_depthwise_convolution_fwd_t::init(engine_t *engine) { diff --git a/src/cpu/aarch64/acl_gemm_convolution.cpp b/src/cpu/aarch64/acl_gemm_convolution.cpp index 5934fd24102..c172f4b395e 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.cpp +++ b/src/cpu/aarch64/acl_gemm_convolution.cpp @@ -52,21 +52,31 @@ template ::pd_t::init( engine_t *engine) { using namespace data_type; - using smask_t = primitive_attr_t::skip_mask_t; bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) && expect_data_types(src_t, wei_t, bia_t, dst_t, undef) - && !has_zero_dim_memory() - && attr()->has_default_values( - smask_t::post_ops | smask_t::fpmath_mode, dst_t); + && !has_zero_dim_memory() && output_scales_mask_ok() + && zero_points_ok(); + if (!ok) return status::unimplemented; if (weights_md_.ndims != 4) return status::unimplemented; + // reject in case the op is running in a Neoverse-N1 with mixed sign quantization. 
+ if (!arm_compute::CPUInfo::get().has_i8mm() && src_t == dnnl_u8 + && wei_t == dnnl_s8 && dst_t == dnnl_u8) + return status::unimplemented; + + // currently, only CpuGemmConv2d has the static quantization update interface. + acp_.is_quantized + = utils::one_of(dst_md_.data_type, data_type::s8, data_type::u8); + // General Compute Library checks, memory tags are also set there CHECK(acl_convolution_utils::acl_init_conf( acp_, src_md_, weights_md_, dst_md_, bias_md_, *desc(), *attr())); + CHECK(post_ops.init(engine, attr_.post_ops_, dst_md_, acp_.act_info)); + // Validate convolution manually to check for return status ACL_CHECK_VALID(Op::validate(&acp_.src_tensor_info, &acp_.wei_tensor_info, acp_.with_bias ? &acp_.bia_tensor_info : nullptr, @@ -82,7 +92,25 @@ status_t acl_gemm_convolution_fwd_t::pd_t::init( auto scratchpad = scratchpad_registry().registrar(); const auto mem_req = conv.workspace(); return init_scratchpad(conv, scratchpad, gemm_conv_keys, engine, post_ops, - attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_); + attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_, + bias_md_, acp_.is_quantized); +} + +template +bool acl_gemm_convolution_fwd_t::pd_t::output_scales_mask_ok() const { + int mask_src = attr()->scales_.get(DNNL_ARG_SRC).mask_; + int mask_wei = attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; + int mask_dst = attr()->scales_.get(DNNL_ARG_DST).mask_; + return mask_src == 0 && mask_wei == 0 && mask_dst == 0; +} + +template +bool acl_gemm_convolution_fwd_t::pd_t::zero_points_ok() const { + return attr()->zero_points_.common(); } template ::execute_forward( using namespace data_type; template struct acl_gemm_convolution_fwd_t; template struct acl_gemm_convolution_fwd_t; -template struct acl_gemm_convolution_fwd_t; +template struct acl_gemm_convolution_fwd_t; +template struct acl_gemm_convolution_fwd_t; } // namespace aarch64 } // namespace cpu diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/aarch64/acl_gemm_convolution.hpp index 6b40f0efff4..bb765988000 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_gemm_convolution.hpp @@ -42,6 +42,8 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { "gemm:acl", acl_gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); status_t init(engine_t *engine); + bool output_scales_mask_ok() const; + bool zero_points_ok() const; acl_conv_conf_t acp_ = utils::zero(); acl_post_ops_t post_ops; diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp b/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp index 44cd03620d9..51a7f5fa2f7 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp +++ b/src/cpu/aarch64/acl_indirect_gemm_convolution.cpp @@ -119,7 +119,7 @@ status_t acl_indirect_gemm_convolution_fwd_t::pd_t::init(engine_t *engine) { auto scratchpad = scratchpad_registry().registrar(); return init_scratchpad(conv, scratchpad, indirect_conv_keys, engine, post_ops, attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, - dst_md_); + dst_md_, bias_md_, false); } } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_winograd_convolution.cpp b/src/cpu/aarch64/acl_winograd_convolution.cpp index ebdb99f50ed..9a2781ba596 100644 --- a/src/cpu/aarch64/acl_winograd_convolution.cpp +++ b/src/cpu/aarch64/acl_winograd_convolution.cpp @@ -75,7 +75,8 @@ status_t acl_wino_convolution_fwd_t::pd_t::init(engine_t *engine) { auto scratchpad = scratchpad_registry().registrar(); const auto aux_mem = conv.workspace(); return init_scratchpad(conv, scratchpad, wino_conv_keys, 
engine, post_ops, - attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_); + attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_, + bias_md_, acp_.is_quantized); } status_t acl_wino_convolution_fwd_t::init(engine_t *engine) { diff --git a/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.cpp b/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.cpp new file mode 100644 index 00000000000..7828a1dd75d --- /dev/null +++ b/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.cpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2024 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace matmul { +status_t acl_lowp_matmul_sq_resource_t::configure( + const acl_lowp_matmul_sq_conf_t &almc) { + if (!acl_obj_) return status::out_of_memory; + acl_obj_->src_tensor.allocator()->init(almc.src_tensor_info); + acl_obj_->wei_tensor.allocator()->init(almc.wei_tensor_info); + if (almc.with_bias) { + acl_obj_->bia_tensor.allocator()->init(almc.bia_tensor_info); + } + acl_obj_->dst_tensor.allocator()->init(almc.dst_tensor_info); + arm_compute::QuantizationInfo qi {1.0, 0, true}; + acl_obj_->src_tensor.info()->set_quantization_info(qi); + acl_obj_->wei_tensor.info()->set_quantization_info(qi); + acl_obj_->dst_tensor.info()->set_quantization_info(qi); + acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + almc.with_bias ? &acl_obj_->bia_tensor : nullptr, + &acl_obj_->dst_tensor, almc.gemm_info); + return status::success; +} +status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) { + VDISPATCH_MATMUL(set_default_formats(), "failed to set default formats"); + using smask_t = primitive_attr_t::skip_mask_t; + VDISPATCH_MATMUL( + attr()->has_default_values(smask_t::scales_runtime + | smask_t::zero_points_runtime | smask_t::post_ops), + "only scale, zero point and post-ops attrs supported"); + VDISPATCH_MATMUL(attr()->scales_.get(DNNL_ARG_SRC).mask_ == 0 + && attr()->zero_points_.get(DNNL_ARG_SRC) == 0 + && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 + && attr()->zero_points_.get(DNNL_ARG_WEIGHTS) == 0 + && attr()->scales_.get(DNNL_ARG_DST).mask_ == 0 + && attr()->zero_points_.get(DNNL_ARG_DST) == 0, + "common scales and zero points only"); + VDISPATCH_MATMUL( + !has_runtime_dims_or_strides(), VERBOSE_RUNTIMEDIM_UNSUPPORTED); + const memory_desc_wrapper src_d(src_md_); + const memory_desc_wrapper wei_d(weights_md_); + const memory_desc_wrapper bia_d(bias_md_); + const memory_desc_wrapper dst_d(dst_md_); + using namespace data_type; + VDISPATCH_MATMUL(utils::one_of(src_d.data_type(), s8, u8) + && wei_d.data_type() == s8 + && src_d.data_type() == s8 + ? 
dst_d.data_type() == s8 + : dst_d.data_type() == u8, + VERBOSE_UNSUPPORTED_DT_CFG); + VDISPATCH_MATMUL(utils::one_of(bia_d.data_type(), f32, undef), + VERBOSE_UNSUPPORTED_DT_CFG); + // reject in case the op is running in a Neoverse-N1. + VDISPATCH_MATMUL(arm_compute::CPUInfo::get().has_i8mm(), + "Neoverse-N1 not supported"); + VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab) + && wei_d.matches_tag(format_tag::ab) + && dst_d.matches_tag(format_tag::ab), + VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_MATMUL_SC( + memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims, + bias_md_.data_type, format_tag::ab), + VERBOSE_UNSUPPORTED_BIAS_CFG); + // We set the QuantizationInfo to be dynamic because it is re-set in run() + almc_.src_tensor_info + = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1, + acl_utils::get_acl_data_t(src_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.src_tensor_info.set_are_values_constant(false); + almc_.wei_tensor_info + = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1, + acl_utils::get_acl_data_t(wei_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.wei_tensor_info.set_are_values_constant(false); + almc_.dst_tensor_info + = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1, + acl_utils::get_acl_data_t(dst_d.data_type(), true), + arm_compute::QuantizationInfo(1.0, 0, true)); + almc_.bia_tensor_info = arm_compute::TensorInfo( + arm_compute::TensorShape(), 1, arm_compute::DataType::S32); + almc_.with_bias = bia_d.format_kind() != format_kind::undef; + if (almc_.with_bias) { + // This is not currently guarded in ACL + VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1 + && bia_d.dims()[1] == N(), + "Only 1xN bias is supported"); + almc_.bia_tensor_info.set_tensor_shape( + arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0])); + } + arm_compute::GEMMLowpOutputStageInfo info; + info.type = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + info.gemmlowp_multiplier = 1073741824; + info.gemmlowp_shift = -1; + info.gemmlowp_offset = 0; + info.gemmlowp_min_bound = -128; + info.gemmlowp_max_bound = 127; + info.output_data_type = almc_.dst_tensor_info.data_type(); + almc_.gemm_info.set_gemmlowp_output_stage(info); + auto scratchpad = scratchpad_registry().registrar(); + const dnnl::impl::memory_desc_t dst_md_ {desc_.dst_desc}; + arm_compute::ActivationLayerInfo act_info; + CHECK(init_scratchpad(engine, scratchpad, acl_post_ops, attr_.post_ops_, + act_info, dst_md_)); + almc_.gemm_info.set_activation_info(act_info); + ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate( + &almc_.src_tensor_info, &almc_.wei_tensor_info, + almc_.with_bias ? &almc_.bia_tensor_info : nullptr, + &almc_.dst_tensor_info, almc_.gemm_info)); + return status::success; +} +status_t acl_lowp_matmul_sq_t::pd_t::init_scratchpad(engine_t *engine, + memory_tracking::registrar_t &scratchpad, acl_post_ops_t &post_ops, + dnnl::impl::post_ops_t &attr_post_ops, + arm_compute::ActivationLayerInfo &act_info, + const dnnl::impl::memory_desc_t &dst_md) { + CHECK(post_ops.init(engine, attr_post_ops, dst_md, act_info)); + // ACL only accepts s32 bias for quantization and since + // the current bias vector is f32 we need to convert. 
+ if (almc_.with_bias) { + const memory_desc_wrapper bias_d(&bias_md_); + scratchpad.book(memory_tracking::names::key_conv_bias_s32_convert, + bias_d.nelems(), bias_d.data_type_size()); + } + return status::success; +} +status_t acl_lowp_matmul_sq_t::create_resource( + engine_t *engine, resource_mapper_t &mapper) const { + if (mapper.has_resource(this)) return status::success; + auto r = utils::make_unique(); + if (!r) return status::out_of_memory; + CHECK(r->configure(pd()->almc_)); + mapper.add(this, std::move(r)); + return status::success; +} +status_t acl_lowp_matmul_sq_t::execute(const exec_ctx_t &ctx) const { + std::lock_guard _lock {this->mtx}; + bool with_bias = pd()->almc_.with_bias; + acl_lowp_matmul_sq_obj_t &acl_obj + = ctx.get_resource_mapper() + ->get(this) + ->get_acl_obj(); + auto src = CTX_IN_MEM(const int8_t *, DNNL_ARG_SRC); + auto wei = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); + auto dst = CTX_OUT_MEM(const int8_t *, DNNL_ARG_DST); + acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); + acl_obj.wei_tensor.allocator()->import_memory(const_cast(wei)); + acl_obj.dst_tensor.allocator()->import_memory(const_cast(dst)); + DEFINE_ARG_SCALES_BUFFER(src_scale, DNNL_ARG_SRC); + DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); + DEFINE_ARG_SCALES_BUFFER(wei_scale, DNNL_ARG_WEIGHTS); + DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS); + DEFINE_ARG_SCALES_BUFFER(dst_scale, DNNL_ARG_DST); + DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); + if (with_bias) { + const auto scratchpad = ctx.get_scratchpad_grantor(); + auto bia_s32_base = scratchpad.get( + memory_tracking::names::key_conv_bias_s32_convert); + auto bia_f32_base = CTX_IN_MEM(const float32_t *, DNNL_ARG_BIAS); + const float bias_scale = 1 / (*src_scale * (*wei_scale)); + const int num_elements + = acl_obj.bia_tensor.info()->total_size() / sizeof(float32_t); + parallel_nd(num_elements, [&](dim_t e) { + const auto b = int32_t(std::round(bia_f32_base[e] * bias_scale)); + bia_s32_base[e] = b; + }); + acl_obj.bia_tensor.allocator()->init(*acl_obj.bia_tensor.info()); + acl_obj.bia_tensor.allocator()->import_memory(bia_s32_base); + } + acl_obj.src_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo(*src_scale, -src_zero_point, true)); + acl_obj.wei_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo(*wei_scale, -wei_zero_point, true)); + // for efficiency reasons, OneDNN saves the inverse of the destination + acl_obj.dst_tensor.info()->set_quantization_info( + arm_compute::QuantizationInfo( + 1.0 / (*dst_scale), dst_zero_point, true)); + // The two calls below are stateful and, therefore, not fully thread-safe. + // This issue is being addressed, and the lock will be removed when the + // matmul stateless work is finished. 
+ acl_obj.gemm.update_quantization_parameters(); + acl_obj.gemm.run(); + // free() here tells ACL it can no longer use it, it does not deallocate + acl_obj.src_tensor.allocator()->free(); + acl_obj.wei_tensor.allocator()->free(); + if (with_bias) { acl_obj.bia_tensor.allocator()->free(); } + acl_obj.dst_tensor.allocator()->free(); + return status::success; +}; +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl \ No newline at end of file diff --git a/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp b/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp new file mode 100644 index 00000000000..9cfae6d1172 --- /dev/null +++ b/src/cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2024 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ACL_LOWP_MATMUL_SQ_HPP +#define ACL_LOWP_MATMUL_SQ_HPP + +#include + +#include "cpu/cpu_primitive.hpp" +#include "cpu/matmul/cpu_matmul_pd.hpp" +#include "cpu/matmul/matmul_utils.hpp" + +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h" +#include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h" +#include "cpu/aarch64/acl_post_ops.hpp" +#include "cpu/aarch64/acl_utils.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace matmul { + +struct acl_lowp_matmul_sq_obj_t { + arm_compute::GEMMLowpOutputStageInfo info; + arm_compute::NEGEMMLowpMatrixMultiplyCore gemm; + arm_compute::Tensor src_tensor; + arm_compute::Tensor wei_tensor; + arm_compute::Tensor bia_tensor; + arm_compute::Tensor dst_tensor; +}; + +struct acl_lowp_matmul_sq_conf_t { + bool with_bias; + arm_compute::TensorInfo src_tensor_info; + arm_compute::TensorInfo wei_tensor_info; + arm_compute::TensorInfo bia_tensor_info; + arm_compute::TensorInfo dst_tensor_info; + arm_compute::GEMMInfo gemm_info; +}; + +struct acl_lowp_matmul_sq_resource_t : public resource_t { + acl_lowp_matmul_sq_resource_t() + : acl_obj_(utils::make_unique()) {} + + status_t configure(const acl_lowp_matmul_sq_conf_t &almc); + + acl_lowp_matmul_sq_obj_t &get_acl_obj() const { return *acl_obj_; } + + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_lowp_matmul_sq_resource_t); + +private: + std::unique_ptr acl_obj_; +}; + +struct acl_lowp_matmul_sq_t : public primitive_t { + struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t { + + pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr, + const cpu_matmul_pd_t *hint_fwd_pd) + : cpu_matmul_pd_t(adesc, attr, hint_fwd_pd), almc_() {} + + using cpu_matmul_pd_t::cpu_matmul_pd_t; + + DECLARE_COMMON_PD_T("lowp_gemm_sq:acl", acl_lowp_matmul_sq_t, + USE_GLOBAL_SCRATCHPAD); + + status_t init(engine_t *engine); + + status_t init_scratchpad(engine_t *engine, + memory_tracking::registrar_t &scratchpad, + 
acl_post_ops_t &post_ops, dnnl::impl::post_ops_t &attr_post_ops, + arm_compute::ActivationLayerInfo &act_info, + const dnnl::impl::memory_desc_t &dst_md); + + acl_lowp_matmul_sq_conf_t almc_; + acl_post_ops_t acl_post_ops; + }; + + acl_lowp_matmul_sq_t(const pd_t *apd) : primitive_t(apd) {} + + status_t create_resource(engine_t *engine, resource_mapper_t &mapper) const; + + status_t execute(const exec_ctx_t &ctx) const; + +private: + mutable std::mutex mtx; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } +}; + +} // namespace matmul +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_AARCH64_ACL_LOWP_MATMUL_HPP \ No newline at end of file diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp index bb02c045d8e..27d59fe8fac 100644 --- a/src/cpu/cpu_convolution_list.cpp +++ b/src/cpu/cpu_convolution_list.cpp @@ -490,6 +490,7 @@ const std::map> &impl_list_map() }}, {{forward, s8, s8, s8}, { CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) CPU_INSTANCE_AMX(brgemm_convolution_fwd_t) @@ -510,7 +511,6 @@ const std::map> &impl_list_map() CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t) CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t) CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t) CPU_INSTANCE(ref_convolution_int8_fwd_t) CPU_INSTANCE(ref_fused_convolution_fwd_t) @@ -642,6 +642,7 @@ const std::map> &impl_list_map() nullptr, }}, {{forward, u8, s8, u8}, { + CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t) CPU_INSTANCE_X64(ip_convolution_fwd_t) CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t) diff --git a/src/cpu/matmul/cpu_matmul_list.cpp b/src/cpu/matmul/cpu_matmul_list.cpp index 6a53d0920c6..e68015627c3 100644 --- a/src/cpu/matmul/cpu_matmul_list.cpp +++ b/src/cpu/matmul/cpu_matmul_list.cpp @@ -34,6 +34,7 @@ using namespace dnnl::impl::cpu::x64; #include "cpu/aarch64/matmul/brgemm_matmul.hpp" #ifdef DNNL_AARCH64_USE_ACL #include "cpu/aarch64/matmul/acl_lowp_matmul.hpp" +#include "cpu/aarch64/matmul/acl_lowp_matmul_sq.hpp" #include "cpu/aarch64/matmul/acl_matmul.hpp" #endif using namespace dnnl::impl::cpu::aarch64::matmul; @@ -73,10 +74,11 @@ using namespace dnnl::impl::cpu::matmul; // clang-format off constexpr impl_list_item_t impl_list[] = REG_MATMUL_P({ - CPU_INSTANCE_AARCH64(brgemm_matmul_t) + CPU_INSTANCE_AARCH64(brgemm_matmul_t) + CPU_INSTANCE_AARCH64_ACL(acl_lowp_matmul_sq_t) CPU_INSTANCE_AARCH64_ACL(acl_lowp_matmul_t) - CPU_INSTANCE_AARCH64_ACL(acl_matmul_t) - CPU_INSTANCE_AARCH64(brgemm_matmul_t) + CPU_INSTANCE_AARCH64_ACL(acl_matmul_t) + CPU_INSTANCE_AARCH64(brgemm_matmul_t) CPU_INSTANCE_AMX(brgemm_matmul_t) CPU_INSTANCE_AMX(brgemm_matmul_t) CPU_INSTANCE_AVX512(brgemm_matmul_t)
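
Usage sketch (not part of the patch): the new ACL int8 convolution path is exercised through the regular oneDNN v3.x C++ API with per-tensor (mask = 0) scales and zero points, which is what output_scales_mask_ok() and zero_points_ok() above accept. Shapes below are illustrative, and actual dispatch to acl_gemm_convolution_fwd_t additionally depends on building with DNNL_AARCH64_USE_ACL and, for the u8:s8:u8 configuration, on i8mm support as checked in pd_t::init().

#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    // Illustrative shapes: 1x64x56x56 u8 src, 64x64x3x3 s8 weights, u8 dst, f32 bias.
    memory::desc src_md({1, 64, 56, 56}, memory::data_type::u8, memory::format_tag::any);
    memory::desc wei_md({64, 64, 3, 3}, memory::data_type::s8, memory::format_tag::any);
    memory::desc bia_md({64}, memory::data_type::f32, memory::format_tag::a);
    memory::desc dst_md({1, 64, 56, 56}, memory::data_type::u8, memory::format_tag::any);

    // Per-tensor quantization parameters (mask = 0); the actual values are
    // supplied at execution time via DNNL_ARG_ATTR_SCALES | DNNL_ARG_* and
    // DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_* memory arguments.
    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    attr.set_scales_mask(DNNL_ARG_DST, 0);
    attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
    attr.set_zero_points_mask(DNNL_ARG_DST, 0);

    convolution_forward::primitive_desc conv_pd(eng,
            prop_kind::forward_inference, algorithm::convolution_direct,
            src_md, wei_md, bia_md, dst_md,
            {1, 1} /*strides*/, {1, 1} /*pad_l*/, {1, 1} /*pad_r*/, attr);
    auto conv = convolution_forward(conv_pd);
    (void)conv; // execution with runtime scale/zero-point memories omitted
    return 0;
}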
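
For reference, the bias handling in execute_forward_conv_acl() and acl_lowp_matmul_sq_t::execute() follows the usual requantization identity: ACL consumes an s32 bias expressed in the combined src*wei scale, so each f32 element is multiplied by 1/(src_scale * wei_scale) and rounded; the destination quantization is passed as 1/dst_scale because oneDNN stores the inverse of the destination scale, and for the QASYMM8 (u8) destination the src/wei zero points are negated before being handed to ACL. A standalone sketch of the bias conversion (function name is illustrative):

#include <cmath>
#include <cstdint>
#include <vector>

// Convert an f32 oneDNN bias into the s32 bias ACL expects, using the same
// scaling and rounding as the scratchpad conversion added in this patch.
std::vector<int32_t> convert_bias_to_s32(const std::vector<float> &bias_f32,
        float src_scale, float wei_scale) {
    const float bias_scale = 1.f / (src_scale * wei_scale);
    std::vector<int32_t> bias_s32(bias_f32.size());
    for (std::size_t i = 0; i < bias_f32.size(); ++i)
        bias_s32[i] = static_cast<int32_t>(std::round(bias_f32[i] * bias_scale));
    return bias_s32;
}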
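
Similarly, the user-level entry point for acl_lowp_matmul_sq_t would be a 2D s8 matmul with per-tensor scales/zero points, row-major (ab) tags and an optional 1xN f32 bias, matching the checks in its pd_t::init(). A hedged sketch against the oneDNN v3.x C++ API; sizes are made up, and selection of this implementation still depends on the ACL build and on i8mm being available:

#include "dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    const memory::dim M = 128, K = 256, N = 512; // illustrative sizes
    memory::desc src_md({M, K}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::s8, memory::format_tag::ab);
    memory::desc bia_md({1, N}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::s8, memory::format_tag::ab);

    // Common (per-tensor) scales and zero points only, as required above.
    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    attr.set_scales_mask(DNNL_ARG_DST, 0);
    attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
    attr.set_zero_points_mask(DNNL_ARG_DST, 0);

    matmul::primitive_desc mm_pd(eng, src_md, wei_md, bia_md, dst_md, attr);
    auto mm = matmul(mm_pd);
    (void)mm; // execution with runtime scale/zero-point memories omitted
    return 0;
}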