Skip to content

Commit

Permalink
Create UpdateConstants utility function
Browse files Browse the repository at this point in the history
Change-Id: I73c8c6f92cfe3be3a7e811e98a4eec17590186ff
  • Loading branch information
mbaret committed Oct 29, 2020
1 parent a787166 commit d48ae14
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 46 deletions.
27 changes: 1 addition & 26 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,32 +368,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
ICHECK(ext_func.defined()) << "External function is not defined.";

// Step into the functions that are handled by external codegen to
// collect metadata.
auto codegen = func->GetAttr<String>(attr::kCompiler);
ICHECK(codegen.defined()) << "No external codegen is set";
std::string codegen_name = codegen.value();
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
std::string symbol = std::string(name_node.value());
std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater";
// Get the constant updater for the external codegen
auto pf = tvm::runtime::Registry::Get(const_update_name);
// If the backend hasn't registered a constant updater, use a default one
if (pf == nullptr) {
ConstantUpdater const_visit(symbol, &params_);
const_visit(func);
} else {
Map<String, tvm::runtime::NDArray> constants = (*pf)(func, symbol);
for (const auto& it : constants) {
std::string const_name(it.first);
// Constant names should begin this the compiler name (to avoid conflicts)
ICHECK(const_name.find(codegen_name) == 0)
<< "External constant names must start with compiler name";
params_[const_name] = it.second;
}
}

UpdateConstants(func, &params_);
return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
}

Expand Down
31 changes: 31 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,37 @@ struct ConstantUpdater : public ExprVisitor {
std::unordered_map<std::string, runtime::NDArray>* params_;
};

/*!
* \brief A function to update the params with constants found in an external function.
* \param func The function from which to get the constant params.
* \param params The params to update with the constants.
*/
inline void UpdateConstants(Function func,
std::unordered_map<std::string, runtime::NDArray>* params) {
auto codegen = func->GetAttr<String>(attr::kCompiler);
ICHECK(codegen.defined()) << "No external codegen is set";
std::string codegen_name = codegen.value();
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
std::string symbol = std::string(name_node.value());
std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater";
// Get the constant updater for the external codegen
auto pf = tvm::runtime::Registry::Get(const_update_name);
// If the backend hasn't registered a constant updater, use a default one
if (pf == nullptr) {
ConstantUpdater const_visit(symbol, params);
const_visit(func);
} else {
Map<String, tvm::runtime::NDArray> constants = (*pf)(func, symbol);
for (const auto& it : constants) {
std::string const_name(it.first);
// Constant names should begin this the compiler name (to avoid conflicts)
ICHECK(const_name.find(codegen_name) == 0)
<< "External constant names must start with compiler name";
(*params)[const_name] = it.second;
}
}
}

/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
Expand Down
21 changes: 1 addition & 20 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1125,26 +1125,7 @@ void VMCompiler::Codegen() {
// Collect metadata in functions that are handled by external codegen.
ICHECK(mod->ContainGlobalVar(cfunc->func_name));
Function func = Downcast<Function>(mod->Lookup(cfunc->func_name));
auto codegen = func->GetAttr<String>(attr::kCompiler);
ICHECK(codegen.defined()) << "No external codegen is set";
std::string codegen_name = codegen.value();
std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater";
// Get the constant updater for the external codegen
auto pf = tvm::runtime::Registry::Get(const_update_name);
// If the backend hasn't registered a constant updater, use a default one
if (pf == nullptr) {
backend::ConstantUpdater const_visit(cfunc->func_name, &params_);
const_visit(func);
} else {
Map<String, tvm::runtime::NDArray> constants = (*pf)(func, cfunc->func_name);
for (const auto& it : constants) {
std::string const_name(it.first);
// Constant names should begin this the compiler name (to avoid conflicts)
ICHECK(const_name.find(codegen_name) == 0)
<< "External constant names must start with compiler name";
params_[const_name] = it.second;
}
}
backend::UpdateConstants(func, &params_);
continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, mod);
Expand Down

0 comments on commit d48ae14

Please sign in to comment.