Skip to content

Commit

Permalink
[RELAY] Hotfix build_module creation (#3198)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored May 16, 2019
1 parent 493f90f commit ac3f5bd
Showing 1 changed file with 23 additions and 38 deletions.
61 changes: 23 additions & 38 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/

#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
Expand All @@ -40,31 +39,6 @@ namespace backend {

using TargetsMap = Map<tvm::Integer, tvm::Target>;

/*!
* \brief Context index to Target
*/
struct ContextTargetMap {
static const std::unordered_map<int, tvm::Target> mask2str;
static tvm::Target Mask2Str(int mask) {
CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
return mask2str.at(mask);
}
};

const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
{1, tvm::Target::create("llvm")},
{2, tvm::Target::create("cuda")},
{4, tvm::Target::create("opencl")},
{5, tvm::Target::create("aocl")},
{6, tvm::Target::create("sdaccel")},
{7, tvm::Target::create("vulkan")},
{8, tvm::Target::create("metal")},
{9, tvm::Target::create("vpi")},
{10, tvm::Target::create("rocm")},
{11, tvm::Target::create("opengl")},
{12, tvm::Target::create("ext_dev")}
};

/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
Expand Down Expand Up @@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return Array<StringImm> names of params
*/
Array<HalideIR::Expr> ListParamNames() {
Array<HalideIR::Expr> ret;
Array<tvm::Expr> ListParamNames() {
Array<tvm::Expr> ret;
for (const auto& kv : params_) {
ret.push_back(ir::StringImm::make(kv.first));
}
Expand Down Expand Up @@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (cfg.pass_enabled("AlterOpLayout")) {
if (targets.size() == 1) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
auto enter_pf = GetPackedFunc("_EnterTargetScope");
auto exit_pf = GetPackedFunc("_ExitTargetScope");
for (const auto& kv : targets) {
(*enter_pf)(kv.second);
TargetContext tctx(kv.second);
func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
(*exit_pf)();
}
} else {
LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
Expand All @@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return func;
}

/*!
* \brief Create a default type.
* \param device_type The device type index.
* \return the default target for the device.
*/
Target CreateDefaultTarget(int device_type) {
std::string name = runtime::DeviceName(device_type);
if (name == "cpu") return Target::create("llvm");
if (name == "gpu") return Target::create("cuda");
return Target::create(name);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
Expand All @@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (tmp_map.count(cfg.fallback_device) == 0) {
device_target.Set(
cfg.fallback_device,
ContextTargetMap::Mask2Str(cfg.fallback_device));
CreateDefaultTarget(cfg.fallback_device));
}
return device_target;
}
Expand All @@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr
* \return Function
*/
Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
Function RunDeviceAnnotationPass(Function func,
const RelayBuildConfig& cfg,
TargetsMap* targets_map_ptr) {
func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
Expand All @@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
"relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
if (annotation_map.size() == 0) {
targets_map_ptr->Set(
0, ContextTargetMap::Mask2Str(cfg.fallback_device));
0, CreateDefaultTarget(cfg.fallback_device));
} else {
int64_t dev_type = -1;
for (auto kv : annotation_map) {
Expand All @@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<< "found. Please check the "
<< "RewriteAnnotation pass.";
}
targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
}
}
return func;
Expand Down Expand Up @@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RelayBuildCreate();
});

Expand Down

0 comments on commit ac3f5bd

Please sign in to comment.