Skip to content

Commit

Permalink
move external function codegen to compile_engine
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 9, 2019
1 parent 4c9761b commit a52e18b
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 94 deletions.
3 changes: 3 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());

/*! \return A target for ext_dev */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target

/*!
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def load(path, fmt=""):
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files)
path += ".so"
# TODO(weberlo): we should probably use a more distinctive suffix for uTVM object files
elif path.endswith(".obj"):
fmt = "micro_dev"
# Redirect to the load API
return _LoadFromFile(path, fmt)

Expand Down
4 changes: 4 additions & 0 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector<std::string>& options) {
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}

Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target

bool LLVMEnabled() {
Expand Down
32 changes: 14 additions & 18 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,8 @@ struct GraphCodegen {
return CallFunc<std::string>("get_graph_json", nullptr);
}

Array<Function> GetExternalFuncs() {
return CallFunc<Array<Function> >("get_external_funcs", nullptr);
}

runtime::Module GetExternalModule() {
return CallFunc<runtime::Module>("get_external_module", nullptr);
Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
}

Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
Expand Down Expand Up @@ -156,13 +152,9 @@ class RelayBuildModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
});
} else if (name == "get_external_funcs") {
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalFuncs();
});
} else if (name == "get_external_module") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalModule();
*rv = this->graph_codegen_->GetExternalModules();
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -490,14 +482,18 @@ class RelayBuildModule : public runtime::ModuleNode {
target_host_,
BuildConfig::Current());
}
Array<Function> external_funcs = graph_codegen_->GetExternalFuncs();
if (!external_funcs.empty()) {
auto ext_rt_mod = graph_codegen_->GetExternalModule();
// Execute the whole module using external runtime.
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
ret_.mod = ext_rt_mod;
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
ret_.mod.Import(ext_rt_mod);
// Import all external runtime modules.
for (const auto& it : ext_mods) {
ret_.mod.Import(it);
}
}
}
}
Expand Down
83 changes: 70 additions & 13 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
Expand Down Expand Up @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->IsExternal()) {
auto compiler = FunctionGetAttr(src_func, attr::kExternal);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName);
const tvm::ir::StringImm* func_name = ext_func_name.as<tvm::ir::StringImm>();
CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(func_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}

Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}

// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}

void Clear() final {
cache_.clear();
}
Expand Down Expand Up @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external function for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->IsExternal()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kFuncName).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
}
// Enforce use the target.
With<Target> target_scope(key->target);

Expand Down Expand Up @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
}


TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
});
return CompileEngine::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
self->Clear();
});
self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
return self->Lower(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
return self->LowerShapeFunc(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
return self->LowerExternalFunctions();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
return self->JIT(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<NodeRef>(CompileEngine)>(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
} // namespace tvm
7 changes: 7 additions & 0 deletions src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -186,6 +187,12 @@ class CompileEngineNode : public Node {
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*!
* \brief Lower the external function using external codegen tools.
* \return The runtime moduels for each needed external codegen tool.
*/
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;

/*! \brief clear the cache. */
virtual void Clear() = 0;

Expand Down
Loading

0 comments on commit a52e18b

Please sign in to comment.