Skip to content

Commit

Permalink
[Relay] Port relay.backend.build to c++
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Lu committed Jan 16, 2020
1 parent 4eecd2a commit b889a8f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
27 changes: 0 additions & 27 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,6 @@ def lower(sch, inputs, func_name, source_func):
f, (_container.Array, tuple, list)) else [f]


@register_func("relay.backend.build")
def build(funcs, target, target_host=None):
"""Backend build function.
Parameters
----------
funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
A list of lowered functions or dictionary mapping from targets to
lowered functions.
target : tvm.Target
The target to run the code on.
target_host : tvm.Target
The host target.
Returns
-------
module : tvm.Module
The runtime module.
"""
if target_host == "":
target_host = None
return _build.build(funcs, target=target, target_host=target_host)


@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tvalue):
return str(tvalue.data.asnumpy())
Expand Down
47 changes: 47 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,53 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
*rv = RelayBuildCreate();
});

// Backend build function.
TVM_REGISTER_GLOBAL("relay.backend.build")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GT(args.size(), 0);

const bool allow_not_defined = true;
Target target = Target::Current(allow_not_defined);
if (args.size() > 1) {
target = args[1];
if (!target.defined()) {
target = target::llvm();
}
}

Map<Target, Array<LoweredFunc>> target_flist;
if (args[0].IsObjectRef<Array<LoweredFunc>>()) {
target_flist.Set(target, args[0]);
} else if (args[0].IsObjectRef<Map<Target, Array<LoweredFunc>>>()) {
target_flist = args[0];
} else if (args[0].IsObjectRef<Map<std::string, Array<LoweredFunc>>>()) {
Map<std::string, Array<LoweredFunc>> inputs = args[0];
for (const auto& it : inputs) {
auto tar = Target::Create(it.first);
target_flist.Set(tar, it.second);
}
} else {
LOG(ERROR) << "TypeError: Unsupported data type of args[0] "
<< runtime::TypeCode2Str(args[0].type_code());
}

Target target_host = [=]() -> Target {
if (args.size() == 3) {
return args[2];
}
for (const auto& it : target_flist) {
if (it.first->device_type == kDLCPU) {
return it.first;
}
}
const PackedFunc* pf = runtime::Registry::Get("module._Enabled");
CHECK(pf != nullptr);
return (*pf)("llvm") ? target::llvm() : target::stackvm();
}();

*rv = build(target_flist, target_host, BuildConfig::Current());
});

} // namespace backend
} // namespace relay
} // namespace tvm

0 comments on commit b889a8f

Please sign in to comment.