diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index a83288ce3662c..ed5138ed16b10 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector& options = TVM_DLL Target stackvm(const std::vector& options = std::vector()); +/*! \return A target for ext_dev */ +TVM_DLL Target ext_dev(const std::vector& options = + std::vector()); } // namespace target /*! diff --git a/python/tvm/module.py b/python/tvm/module.py index d9676169cc5a2..976fb2d81cc73 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -274,6 +274,9 @@ def load(path, fmt=""): files = [tar_temp.relpath(x) for x in tar_temp.listdir()] _cc.create_shared(path + ".so", files) path += ".so" + # TODO(weberlo): we should probably use a more distinctive suffix for uTVM object files + elif path.endswith(".obj"): + fmt = "micro_dev" # Redirect to the load API return _LoadFromFile(path, fmt) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 80fd57af66f97..a7325a92f50a7 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector& options) { Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } + +Target ext_dev(const std::vector& options) { + return CreateTarget("ext_dev", options); +} } // namespace target bool LLVMEnabled() { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 6d0fe581f9d2e..36139080682e0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -73,12 +73,8 @@ struct GraphCodegen { return CallFunc("get_graph_json", nullptr); } - Array GetExternalFuncs() { - return CallFunc >("get_external_funcs", nullptr); - } - - runtime::Module GetExternalModule() { - return CallFunc("get_external_module", nullptr); + Array GetExternalModules() { + return CallFunc >("get_external_modules", nullptr); } Map > GetLoweredFunc() { @@ -156,13 +152,9 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_codegen_->GetLoweredFunc(); }); - } else if (name == "get_external_funcs") { + } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalFuncs(); - }); - } else if (name == "get_external_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalModule(); + *rv = this->graph_codegen_->GetExternalModules(); }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -490,14 +482,18 @@ class RelayBuildModule : public runtime::ModuleNode { target_host_, BuildConfig::Current()); } - Array external_funcs = graph_codegen_->GetExternalFuncs(); - if (!external_funcs.empty()) { - auto ext_rt_mod = graph_codegen_->GetExternalModule(); - // Execute the whole module using external runtime. + Array ext_mods = graph_codegen_->GetExternalModules(); + if (!ext_mods.empty()) { + CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1) + << "Expect to have a TVM DSOModule when multiple external runtime modules exist"; if (lowered_funcs.size() == 0) { - ret_.mod = ext_rt_mod; + // Execute the whole module using external runtime. + ret_.mod = ext_mods[0]; } else { - ret_.mod.Import(ext_rt_mod); + // Import all external runtime modules. + for (const auto& it : ext_mods) { + ret_.mod.Import(it); + } } } } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 083fa5d5610cd..3d99833fc3cd8 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode { return LowerShapeFuncInternal(key)->cached_func; } + Array LowerExternalFunctions() { + std::unordered_map ext_mods; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + CHECK(src_func.defined()); + if (src_func->IsExternal()) { + auto compiler = FunctionGetAttr(src_func, attr::kExternal); + const tvm::ir::StringImm* code_gen = compiler.as(); + CHECK(code_gen) << "No external codegen is set"; + if (ext_mods.find(code_gen->value) == ext_mods.end()) { + ext_mods[code_gen->value] = relay::ModuleNode::make({}, {}); + } + auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName); + const tvm::ir::StringImm* func_name = ext_func_name.as(); + CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false); + auto gv = GlobalVarNode::make(func_name->value); + ext_mods[code_gen->value]->Add(gv, src_func); + cached_ext_funcs.push_back(it.first); + } + } + + Array ret; + for (const auto& it : ext_mods) { + std::string ext_name = "relay.ext." + it.first; + auto pf = tvm::runtime::Registry::Get(ext_name); + CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; + runtime::Module ext_mod = (*pf)(it.second); + CHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + void Clear() final { cache_.clear(); } @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode { value->use_count = 0; cache_[key] = value; } + // No need to lower external function for now. We will invoke the external + // codegen tool once and lower all functions together. + if (key->source_func->IsExternal()) { + auto cache_node = make_node(); + const auto name_node = + FunctionGetAttr(key->source_func, attr::kFuncName).as(); + CHECK(name_node != nullptr) << "External function has not been attached a name yet."; + cache_node->func_name = name_node->value; + cache_node->target = tvm::target::ext_dev(); + value->cached_func = CachedFunc(cache_node); + return value; + } // Enforce use the target. With target_scope(key->target); @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() { return *inst; } - TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") .set_body_typed(CCacheKeyNode::make); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") .set_body_typed([]() { - return CompileEngine::Global(); - }); + return CompileEngine::Global(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") .set_body_typed([](CompileEngine self) { - self->Clear(); - }); + self->Clear(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->Lower(key); - }); + return self->Lower(key); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->LowerShapeFunc(key); - }); + return self->LowerShapeFunc(key); +}); + +TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") +.set_body_typed([](CompileEngine self) { + return self->LowerExternalFunctions(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->JIT(key); - }); + return self->JIT(key); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") .set_body_typed(CompileEngine)>( [](CompileEngine self){ - return static_cast(self.operator->())->ListItems(); - }); + return static_cast(self.operator->())->ListItems(); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 31e246ecf1fe9..596dfa7154f7f 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -26,6 +26,7 @@ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #include +#include #include #include #include @@ -186,6 +187,12 @@ class CompileEngineNode : public Node { * \return The result. */ virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + /*! \brief clear the cache. */ virtual void Clear() = 0; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index cf5f26fedfa72..2a27e66104054 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -56,7 +56,7 @@ using TargetsMap = std::unordered_map; struct LoweredOutput { std::string graph_json; Map > lowered_funcs; - Array external_funcs; + Array external_mods; std::unordered_map params; }; @@ -214,7 +214,6 @@ class GraphRuntimeCodegen LoweredOutput ret; ret.graph_json = os.str(); ret.params = params_; - ret.external_funcs = external_funcs_; for (auto& kv : lowered_funcs_) { if (ret.lowered_funcs.count(kv.first) == 0) { ret.lowered_funcs.Set(kv.first, Array()); @@ -229,6 +228,7 @@ class GraphRuntimeCodegen } ret.lowered_funcs.Set(kv.first, tmp); } + ret.external_mods = compile_engine_->LowerExternalFunctions(); return ret; } @@ -384,8 +384,8 @@ class GraphRuntimeCodegen return fields; } - std::vector InvokeExternalCodegen(const CallNode* op, const Function& func) { - CHECK(func->IsExternal()); + std::vector GraphAddCallNode(const CallNode* op, + const std::string& op_name) { std::vector inputs; for (auto arg : op->args) { auto res = VisitExpr(arg); @@ -393,11 +393,7 @@ class GraphRuntimeCodegen inputs.push_back(nr); } } - external_funcs_.push_back(func); - const auto name_node = FunctionGetAttr(func, attr::kFuncName).as(); - CHECK(name_node != nullptr) << "External function has not been attached a name yet."; - std::string op_name = name_node->value; - auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name), + auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), op_name, inputs, @@ -415,9 +411,6 @@ class GraphRuntimeCodegen LOG(FATAL) << "Not implemented"; } else if (op->op.as()) { func = GetRef(op->op.as()); - if (func->IsExternal()) { - return InvokeExternalCodegen(op, func); - } } else { LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); } @@ -426,17 +419,26 @@ class GraphRuntimeCodegen << "(i.e functions composed of fusable operator invocations)"; } - CHECK_GE(storage_device_map_.count(expr), 0); auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); + Target target; + // Handle external function + if (func->IsExternal()) { + target = tvm::target::ext_dev(); + CCacheKey key = (*pf0)(func, target); + CachedFunc ext_func = (*pf1)(compile_engine_, key); + CHECK(ext_func.defined()) << "External function is not defined."; + return GraphAddCallNode(op, ext_func->func_name); + } + + CHECK_GE(storage_device_map_.count(expr), 0); auto &device_type = storage_device_map_[expr][1]; auto call_dev_type = device_type[0]->value; - Target target; + // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. - for (auto kv : targets_) { - target = kv.second; - } + const auto& it = targets_.begin(); + target = (*it).second; } else { // heterogeneous execution. std::string call_dev_name; @@ -460,20 +462,7 @@ class GraphRuntimeCodegen lowered_funcs_[target->str()].insert(f); } - std::vector inputs; - for (auto arg : op->args) { - auto res = VisitExpr(arg); - for (auto nr : res) { - inputs.push_back(nr); - } - } - auto& op_name = lowerd_func->func_name; - auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name), - GraphAttrs(), - op_name, - inputs, - GraphAttrs()); - return AddNode(node, expr); + return GraphAddCallNode(op, _GetUniqueName(lowerd_func->func_name)); } std::vector VisitExpr_(const LetNode* op) override { @@ -615,8 +604,6 @@ class GraphRuntimeCodegen std::unordered_map name_map_; /*! \brief compile engine */ CompileEngine compile_engine_; - /*! \brief external functions */ - Array external_funcs_; }; class GraphRuntimeCodegenModule : public runtime::ModuleNode { @@ -668,34 +655,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.lowered_funcs; }); - } else if (name == "get_external_funcs") { + } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.external_funcs; - }); - } else if (name == "get_external_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK(!this->output_.external_funcs.empty()) << "No external function is annotated."; - // Invoke the external codegen to generate a external runtime module. - auto compiler = FunctionGetAttr(output_.external_funcs[0], attr::kExternal); - const tvm::ir::StringImm* code_gen = compiler.as(); - CHECK(code_gen) << "No external codegen is set"; - std::string ext_name = "relay.ext." + code_gen->value; - auto pf = tvm::runtime::Registry::Get(ext_name); - CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - - // Invoke the 3rd party codegen to generate a library for the external - // functions. - relay::Module rly_mod = relay::ModuleNode::make({}, {}); - for (const auto& func : output_.external_funcs) { - auto ext_func_name = FunctionGetAttr(func, attr::kFuncName); - const tvm::ir::StringImm* func_name = ext_func_name.as(); - CHECK(func_name) << "No external function name is set for:\n" << AsText(func, false); - auto gv = GlobalVarNode::make(func_name->value); - rly_mod->Add(gv, func); - } - runtime::Module ext_mod = (*pf)(rly_mod); - CHECK(ext_mod.defined()) << "No external runtime is generated."; - *rv = ext_mod; + *rv = this->output_.external_mods; }); } else { return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3673e9d4449b2..440e1c1d6c731 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -185,7 +185,7 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") bool FunctionNode::IsExternal() const { NodeRef res = FunctionGetAttr(GetRef(this), attr::kExternal); const ir::StringImm* pval = res.as(); - return pval; + return pval != nullptr; } NodeRef FunctionGetAttr(const Function& func, const std::string& key) { diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 0eb1fa16885f0..6a4c811e40763 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -25,7 +25,7 @@ from tvm import relay from tvm.contrib import util -def check_result(mod, map_inputs, out_shape, result, tol=1e-7): +def check_result(mod, map_inputs, out_shape, result, tol=1e-5): with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): json, lib, _ = relay.build(mod, "llvm") test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) @@ -172,6 +172,7 @@ def test_extern_gcc(): check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) +@pytest.mark.skip(reason="Only for DEMO purpose, need to have dnnl for usage") def test_extern_dnnl(): dtype = 'float32' ishape = (1, 32, 14, 14) @@ -215,4 +216,4 @@ def test_extern_dnnl(): test_multi_node_subgraph() test_extern_gcc_single_op() test_extern_gcc() - test_extern_dnnl() + # test_extern_dnnl()