From ef0acee30aa4aa0ff8566c1327eff8e0eb702658 Mon Sep 17 00:00:00 2001 From: Hao Lu Date: Wed, 15 Jan 2020 17:58:16 -0800 Subject: [PATCH] [Relay] Port relay.backend.build to c++ --- python/tvm/relay/backend/_backend.py | 27 ------------- src/relay/backend/build_module.cc | 58 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 860788a4e5d00..ff4e9e637fca1 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -61,33 +61,6 @@ def lower(sch, inputs, func_name, source_func): f, (_container.Array, tuple, list)) else [f] -@register_func("relay.backend.build") -def build(funcs, target, target_host=None): - """Backend build function. - - Parameters - ---------- - funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]] - A list of lowered functions or dictionary mapping from targets to - lowered functions. - - - target : tvm.Target - The target to run the code on. - - target_host : tvm.Target - The host target. - - Returns - ------- - module : tvm.Module - The runtime module. - """ - if target_host == "": - target_host = None - return _build.build(funcs, target=target, target_host=target_host) - - @register_func("relay._tensor_value_repr") def _tensor_value_repr(tvalue): return str(tvalue.data.asnumpy()) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0458dfd55b179..3efe75830d0d9 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -520,6 +520,64 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") *rv = RelayBuildCreate(); }); +// Backend build function. 
+TVM_REGISTER_GLOBAL("relay.backend.build")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  // args[0]: lowered functions — an Array<LoweredFunc>, or a Map keyed by
+  //          Target or by target string.
+  // args[1] (optional): build target; defaults to Target::Current(), or llvm.
+  // args[2] (optional): host target; otherwise inferred below.
+  CHECK_GT(args.size(), 0);
+
+  const bool allow_not_defined = true;
+  Target target = Target::Current(allow_not_defined);
+  if (args.size() > 1) {
+    target = args[1];
+    if (!target.defined()) {
+      target = target::llvm();
+    }
+  }
+
+  // NOTE(review): the extracted patch had all template argument lists
+  // stripped (e.g. "Map>", "IsObjectRef>()"); they are restored here to
+  // match the three accepted input forms dispatched on below.
+  Map<Target, Array<LoweredFunc>> target_flist;
+  if (args[0].IsObjectRef<Array<LoweredFunc>>()) {
+    target_flist.Set(target, args[0]);
+  } else if (args[0].IsObjectRef<Map<Target, Array<LoweredFunc>>>()) {
+    target_flist = args[0];
+  } else if (args[0].IsObjectRef<Map<std::string, Array<LoweredFunc>>>()) {
+    // Keys are target strings; convert each to a Target object.
+    Map<std::string, Array<LoweredFunc>> inputs = args[0];
+    for (const auto& it : inputs) {
+      auto tar = Target::Create(it.first);
+      target_flist.Set(tar, it.second);
+    }
+  } else {
+    LOG(ERROR) << "TypeError: Unsupported data type of args[0] "
+               << runtime::TypeCode2Str(args[0].type_code());
+  }
+
+  // Reject duplicate function names within each target's function list.
+  for (const auto& it : target_flist) {
+    std::unordered_set<std::string> fname_set;
+    const auto& flist = it.second;
+    for (const auto& x : flist) {
+      if (fname_set.count(x->name)) {
+        LOG(ERROR) << "Duplicate function name " << x->name;
+      }
+      fname_set.insert(x->name);
+    }
+  }
+
+  // Host target resolution: explicit args[2] if supplied, else the first
+  // CPU target among the inputs, else llvm when enabled, else stackvm.
+  Target target_host = [=]() -> Target {
+    if (args.size() == 3) {
+      return args[2];
+    }
+    for (const auto& it : target_flist) {
+      if (it.first->device_type == kDLCPU) {
+        return it.first;
+      }
+    }
+    const PackedFunc* pf = runtime::Registry::Get("module._Enabled");
+    CHECK(pf != nullptr);
+    return (*pf)("llvm") ? target::llvm() : target::stackvm();
+  }();
+
+  *rv = build(target_flist, target_host, BuildConfig::Current());
+});
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm