Skip to content

Commit

Permalink
[codegen] heterogeneous build for c++ (apache#3144)
Browse files Browse the repository at this point in the history
* heterogeneous build for c++

* merge relay buildmodule to codegen build

* use module split

* use target_host

* remove sse3

* retrigger ci
  • Loading branch information
zhiics authored and wweic committed May 13, 2019
1 parent 4d2f0ca commit 4608e10
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 78 deletions.
29 changes: 29 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,35 @@ TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* for heterogeneous build.
* \param input The map contains target to a list of lowered functions pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<Target, Array<LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* for heterogeneous build.
* \param input The map contains target string to a list of lowered functions
* pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, Array<LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);

class GenericFuncNode;

/*!
Expand Down
126 changes: 98 additions & 28 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,20 +428,19 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto &x : funcs) {
CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
for (const auto& x : funcs) {
CHECK(all_names.count(x->name) == 0)
<< "Duplicate function name " << x->name;
all_names.insert(x->name);
}

auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);

Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;

for (const auto& x : funcs) {
CHECK(ir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in " << x->func_name()
<< ". Did you forget to bind?";
<< "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?";

if (x->func_type == kMixedFunc) {
auto func = x;
Expand All @@ -450,6 +449,7 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
}

func = ir::ThreadSync(func, "shared");
func = ir::ThreadSync(func, "warp");
func = ir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = ir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
Expand All @@ -465,12 +465,32 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
}
}

for (size_t i = 0; i < fdevice.size(); i++) {
auto warp_size = target->thread_warp_size;
auto func = fdevice[i];
func = ir::LowerWarpMemory(fdevice[i], warp_size);
fdevice.Set(i, func);
}

auto keys = target->keys();
bool target_is_gpu =
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) {
LOG(WARNING) << "Specified target " + target->str() +
" but cannot find device code. Did you forget to bind?";
LOG(WARNING) << "Specified target "
<< target->str()
<< " but cannot find device code. Did you forget to bind?";
}

for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = ir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func);
}

if (target->device_type == target::llvm()->device_type &&
target_host == target) {
CHECK(fdevice.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
}

for (size_t i = 0; i < fhost.size(); ++i) {
Expand All @@ -480,41 +500,91 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
fhost.Set(i, func);
}


for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = ir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func);
}

for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = ir::LowerIntrin(func, target_host_val->target_name);
func = ir::LowerIntrin(func, target_host->target_name);
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
return {fhost, fdevice};
}

runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
// Create a module for a specific device (target). The lowered functions
// associated with the host is returned as well.
runtime::Module DeviceBuild(const Array<LoweredFunc>& fdevice,
const Target& target) {
if (!fdevice.empty()) {
return codegen::Build(fdevice, target->str());
} else {
return runtime::Module(nullptr);
}
}

// Build for heterogeneous execution.
runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
const Target& target_host,
const BuildConfig& config) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];
Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_modules;

Target target_host_val = target_host;
if (!target_host.defined()) {
for (const auto& it : inputs) {
if (it.first->device_type == kDLCPU) {
target_host_val = it.first;
break;
}
}
}

auto mhost = codegen::Build(fhost, target_host_val->str());
if (!target_host_val.defined()) {
target_host_val = DefaultTargetHost(target_host_val);
}

if (fdevice.size() > 0) {
auto mdev = codegen::Build(fdevice, target->str());
mhost.Import(mdev);
for (const auto& it : inputs) {
auto host_dev_funcs =
split_dev_host_funcs(it.second, it.first, target_host_val, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];
// Get the module for a certain target.
runtime::Module mdev = DeviceBuild(fdevice, it.first);
for (const auto& it : fhost) {
fhost_all.push_back(it);
}
device_modules.push_back(mdev);
}

runtime::Module mhost = codegen::Build(fhost_all, target_host_val->str());
// Import all modules
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
}
}
return mhost;
}

// Build for heterogeneous execution when target is a string.
runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
const Target& target_host,
const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> updated_input;
for (const auto& it : inputs) {
auto target = Target::create(it.first);
updated_input.Set(target, it.second);
}
return build(updated_input, target_host, config);
}

// Build for homogeneous execution.
runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> inputs = {{target, funcs}};
return build(inputs, target_host, config);
}

BuildConfig build_config() {
return BuildConfig(make_node<BuildConfigNode>());
}
Expand Down
51 changes: 2 additions & 49 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,52 +601,6 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return func;
}
/*!
* \brief Build module given lowered functions for each target
*
* \param lowered_funcs target_str -> Array<LoweredFunc> map
* \param targets Targets map
* \param cfg Building configuration
*/
void BuildModule(const Map<std::string, Array<LoweredFunc> >& lowered_funcs,
const Map<HalideIR::Expr, HalideIR::Expr>& targets,
const BuildConfig& cfg) {
auto target_host = Target::create(cfg_.fallback_device);
for (const auto& kv : lowered_funcs) {
std::unordered_set<std::string> fname_set;
for (auto f : kv.second) {
if (fname_set.count(f->name)) {
LOG(FATAL) << "Duplicate function name "
<< f->name;
}
fname_set.insert(f->name);
}
}
std::unordered_map<std::string, Target> target_map;
for (const auto& kv : lowered_funcs) {
target_map[kv.first] = Target::create(kv.first);
}
Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_module;
for (const auto& kv : lowered_funcs) {
auto target = target_map[kv.first];
auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg);
for (auto f : host_dev_funcs[0]) {
fhost_all.push_back(f);
}
if (host_dev_funcs[1].size()) {
auto mdev = codegen::Build(host_dev_funcs[1], target->str());
device_module.push_back(mdev);
}
}

auto mhost = codegen::Build(fhost_all, target_host->str());

for (auto mdev : device_module) {
mhost.Import(mdev);
}
ret_.mod = mhost;
}

/*!
* \brief Build relay function to runtime module
Expand Down Expand Up @@ -686,9 +640,8 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();

BuildModule(graph_codegen_->GetLoweredFunc(),
device_target,
tvm_cfg_);
auto target_host = Target::create(target_host_);
ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_);
}

protected:
Expand Down
Loading

0 comments on commit 4608e10

Please sign in to comment.