From d48ae14dad71adf876a72cd63af944dc4fed68c8 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 29 Oct 2020 16:06:53 +0000 Subject: [PATCH] Create UpdateConstants utility function Change-Id: I73c8c6f92cfe3be3a7e811e98a4eec17590186ff --- src/relay/backend/graph_runtime_codegen.cc | 27 +------------------ src/relay/backend/utils.h | 31 ++++++++++++++++++++++ src/relay/backend/vm/compiler.cc | 21 +-------------- 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 655fdf013d1ea..e24d18de931c8 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -368,32 +368,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler); - ICHECK(codegen.defined()) << "No external codegen is set"; - std::string codegen_name = codegen.value(); - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - std::string symbol = std::string(name_node.value()); - std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; - // Get the constant updater for the external codegen - auto pf = tvm::runtime::Registry::Get(const_update_name); - // If the backend hasn't registered a constant updater, use a default one - if (pf == nullptr) { - ConstantUpdater const_visit(symbol, ¶ms_); - const_visit(func); - } else { - Map constants = (*pf)(func, symbol); - for (const auto& it : constants) { - std::string const_name(it.first); - // Constant names should begin this the compiler name (to avoid conflicts) - ICHECK(const_name.find(codegen_name) == 0) - << "External constant names must start with compiler name"; - params_[const_name] = it.second; - } - } - + UpdateConstants(func, ¶ms_); return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3def6359c6153..4426642e8e18b 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -63,6 +63,37 @@ struct ConstantUpdater : public ExprVisitor { std::unordered_map* params_; }; +/*! + * \brief A function to update the params with constants found in an external function. + * \param func The function from which to get the constant params. + * \param params The params to update with the constants. + */ +inline void UpdateConstants(Function func, + std::unordered_map* params) { + auto codegen = func->GetAttr(attr::kCompiler); + ICHECK(codegen.defined()) << "No external codegen is set"; + std::string codegen_name = codegen.value(); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + std::string symbol = std::string(name_node.value()); + std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; + // Get the constant updater for the external codegen + auto pf = tvm::runtime::Registry::Get(const_update_name); + // If the backend hasn't registered a constant updater, use a default one + if (pf == nullptr) { + ConstantUpdater const_visit(symbol, params); + const_visit(func); + } else { + Map constants = (*pf)(func, symbol); + for (const auto& it : constants) { + std::string const_name(it.first); + // Constant names should begin this the compiler name (to avoid conflicts) + ICHECK(const_name.find(codegen_name) == 0) + << "External constant names must start with compiler name"; + (*params)[const_name] = it.second; + } + } +} + /*! * \brief A simple wrapper around ExprFunctor for a single argument case. * The result of visit is memoized. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8b8a38f00e845..f652644afa3c3 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1125,26 +1125,7 @@ void VMCompiler::Codegen() { // Collect metadata in functions that are handled by external codegen. ICHECK(mod->ContainGlobalVar(cfunc->func_name)); Function func = Downcast(mod->Lookup(cfunc->func_name)); - auto codegen = func->GetAttr(attr::kCompiler); - ICHECK(codegen.defined()) << "No external codegen is set"; - std::string codegen_name = codegen.value(); - std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; - // Get the constant updater for the external codegen - auto pf = tvm::runtime::Registry::Get(const_update_name); - // If the backend hasn't registered a constant updater, use a default one - if (pf == nullptr) { - backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); - const_visit(func); - } else { - Map constants = (*pf)(func, cfunc->func_name); - for (const auto& it : constants) { - std::string const_name(it.first); - // Constant names should begin this the compiler name (to avoid conflicts) - ICHECK(const_name.find(codegen_name) == 0) - << "External constant names must start with compiler name"; - params_[const_name] = it.second; - } - } + backend::UpdateConstants(func, ¶ms_); continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod);