From edda830af0bb50edf03477882f9dee4888992dad Mon Sep 17 00:00:00 2001 From: Leo-arm <52416576+Leo-arm@users.noreply.github.com> Date: Fri, 22 Oct 2021 11:15:50 +0100 Subject: [PATCH] [ETHOSN] Match config for is-supported with compilation target (#9160) The Ethos-N variant configuration for the is-supported functionality is now the same as the variant configuration for the actual compilation --- src/relay/backend/contrib/ethosn/codegen.cc | 73 +++++++++++++------ .../backend/contrib/ethosn/codegen_ethosn.h | 18 +++++ 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 97b308e51e18..3e675215e7e0 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -606,25 +606,37 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -auto ctx = transform::PassContext::Current(); -auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() - ? ctx -> GetConfig("relay.ext.ethos-n.options") - : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); +std::unique_ptr EthosnCompiler::m_Queries; + +EthosnError EthosnCompiler::SupportedSetup() { + if (m_Queries == nullptr) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethos-n.options").defined() + ? ctx->GetConfig("relay.ext.ethos-n.options") + : AttrsWithDefaultValues(); + m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + if (m_Queries == nullptr) { + return EthosnError("Could not initialise Ethos-N compiler isSupported"); + } + } + return EthosnError(); +} TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); + err += EthosnCompiler::SupportedSetup(); if (params.is_depthwise) { *rv = !err && - m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } else { - *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + *rv = !err && + EthosnCompiler::GetSupported()->IsConvolutionSupported( + params.bias_info, params.weights_info, params.conv_info, params.activation_info); } }); @@ -633,8 +645,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); - *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported( + params.bias_info, params.weights_info, params.fc_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -642,7 +655,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -650,7 +665,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsPoolingSupported(params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -658,7 +675,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); - *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReshapeSupported(params.new_shape, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -666,8 +685,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); - *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported( + params.lhs_info, params.rhs_info, params.output_quantization_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ -675,7 +695,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); - *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -683,7 +704,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); - *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported(params.input_infos, + params.concat_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -691,7 +714,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); - *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsSplitSupported(params.input_info, params.split_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -699,7 +724,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); - *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported(params.input_info, + params.depth_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -707,7 +734,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); - *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && + EthosnCompiler::GetSupported()->IsReluSupported(params.relu_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 63ae7a3e4704..ca2df05e958d 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -287,6 +287,22 @@ class EthosnCompiler { */ static runtime::Module CreateRuntimeModule(const ObjectRef& ref); + /*! + * \brief Initialise the is-supported functionality of the Ethos-N support library + * with the target variant. + * \return Error object + */ + static EthosnError SupportedSetup(); + + /*! + * \brief Return the is-supported API of the Support Library + * \return A reference to the API. + */ + static std::unique_ptr& GetSupported() { + ICHECK(m_Queries != nullptr); + return m_Queries; + } + private: /*! * \brief Compile a single Relay Ethos-N function into an ordered compiled network. @@ -322,6 +338,8 @@ class EthosnCompiler { */ static std::pair, std::vector> GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network); + + static std::unique_ptr m_Queries; }; runtime::Module CompileEthosn(const ObjectRef& ref) {