Skip to content

Commit

Permalink
asdf
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Xu committed May 7, 2019
1 parent 6ea243a commit 05c828d
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 206 deletions.
19 changes: 14 additions & 5 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,23 +344,32 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \param (optional) returned host functions
* \param (optional) returned dev mods
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret = nullptr,
std::vector<runtime::Module>* devmod_ret = nullptr);
const BuildConfig& config);

class GenericFuncNode;

Expand Down
33 changes: 18 additions & 15 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,10 @@ Array<LoweredFunc> lower(Schedule sch,
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}

runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret,
std::vector<runtime::Module>* devmod_ret) {
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto &x : funcs) {
CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
Expand Down Expand Up @@ -466,12 +464,6 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
}

if (fhost_ret != nullptr) {
for (auto f : fhost) {
fhost_ret->push_back(f);
}
}

auto keys = target->keys();
bool target_is_gpu =
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
Expand Down Expand Up @@ -500,14 +492,25 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
Array<Array<LoweredFunc> > ret;
ret.push_back(fhost);
ret.push_back(fdevice);
return ret;
}

runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];

auto mhost = codegen::Build(fhost, target_host_val->str());

if (fdevice.size() > 0) {
auto mdev = codegen::Build(fdevice, target->str());
if (devmod_ret != nullptr) {
devmod_ret->push_back(mdev);
}
mhost.Import(mdev);
}

Expand Down
Loading

0 comments on commit 05c828d

Please sign in to comment.