Skip to content

Commit

Permalink
[Relay] C++ Build module
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Xu committed Apr 24, 2019
1 parent 3f835bd commit 6ea243a
Show file tree
Hide file tree
Showing 4 changed files with 763 additions and 2 deletions.
6 changes: 5 additions & 1 deletion include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,16 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \param (optional) returned host functions
* \param (optional) returned dev mods
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret = nullptr,
std::vector<runtime::Module>* devmod_ret = nullptr);

class GenericFuncNode;

Expand Down
13 changes: 12 additions & 1 deletion src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ Array<LoweredFunc> lower(Schedule sch,
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret,
std::vector<runtime::Module>* devmod_ret) {
std::unordered_set<std::string> all_names;
for (const auto &x : funcs) {
CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
Expand Down Expand Up @@ -464,6 +466,12 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
}

if (fhost_ret != nullptr) {
for (auto f : fhost) {
fhost_ret->push_back(f);
}
}

auto keys = target->keys();
bool target_is_gpu =
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
Expand Down Expand Up @@ -497,6 +505,9 @@ runtime::Module build(const Array<LoweredFunc>& funcs,

if (fdevice.size() > 0) {
auto mdev = codegen::Build(fdevice, target->str());
if (devmod_ret != nullptr) {
devmod_ret->push_back(mdev);
}
mhost.Import(mdev);
}

Expand Down
Loading

0 comments on commit 6ea243a

Please sign in to comment.