diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 5b5d00b5a9617..334fe169ad41d 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -344,6 +344,19 @@ TVM_DLL Array lower(Schedule sch, const std::string& name, const std::unordered_map& binds, const BuildConfig& config); +/*! +* \brief Split host/device function and running necessary pass before build +* \param funcs The functions to be built. +* \param target The target device to build for. +* \param target_host The target for building host code. To use the default, pass Target() +* \param config The build configuration. +* \return The Array> with 2 elements. First is host function Array, + second is device function array +*/ +TVM_DLL Array > split_dev_host_funcs(const Array& funcs, + const Target& target, + const Target& target_host, + const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from an array of lowered functions. @@ -351,16 +364,12 @@ TVM_DLL Array lower(Schedule sch, * \param target The target device to build for. * \param target_host The target for building host code. To use the default, pass Target() * \param config The build configuration. -* \param (optional) returned host functions -* \param (optional) returned dev mods * \return The built module. */ TVM_DLL runtime::Module build(const Array& funcs, const Target& target, const Target& target_host, - const BuildConfig& config, - Array* fhost_ret = nullptr, - std::vector* devmod_ret = nullptr); + const BuildConfig& config); class GenericFuncNode; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 01057ddfe677d..07ef22e0c296f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -422,12 +422,10 @@ Array lower(Schedule sch, return Array({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } -runtime::Module build(const Array& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config, - Array* fhost_ret, - std::vector* devmod_ret) { +Array > split_dev_host_funcs(const Array& funcs, + const Target& target, + const Target& target_host, + const BuildConfig& config) { std::unordered_set all_names; for (const auto &x : funcs) { CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name; @@ -466,12 +464,6 @@ runtime::Module build(const Array& funcs, } } - if (fhost_ret != nullptr) { - for (auto f : fhost) { - fhost_ret->push_back(f); - } - } - auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); @@ -500,14 +492,25 @@ runtime::Module build(const Array& funcs, func = ir::CombineContextCall(func); fhost.Set(i, func); } + Array > ret; + ret.push_back(fhost); + ret.push_back(fdevice); + return ret; +} + +runtime::Module build(const Array& funcs, + const Target& target, + const Target& target_host, + const BuildConfig& config) { + auto target_host_val = target_host.defined() ? 
target_host : DefaultTargetHost(target); + auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config); + auto& fhost = host_dev_funcs[0]; + auto& fdevice = host_dev_funcs[1]; auto mhost = codegen::Build(fhost, target_host_val->str()); if (fdevice.size() > 0) { auto mdev = codegen::Build(fdevice, target->str()); - if (devmod_ret != nullptr) { - devmod_ret->push_back(mdev); - } mhost.Import(mdev); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 710dfb010ace0..3b8636af7f3a9 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -18,13 +18,11 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file relay/backend/build_module.cc - * \brief Graph runtime codegen + * \brief Code generation for TVM's graph runtime. */ - - #include #include #include @@ -42,6 +40,7 @@ namespace backend { /*! * \brief Context name / index + * See: python/tvm/_ffi/runtime_ctypes.py */ struct ContextMap { static const std::unordered_map mask2str; @@ -98,7 +97,11 @@ const std::unordered_map ContextMap::mask2str = const std::unordered_map ContextMap::str2mask = ContextMap::_declare_str2mask(); -/*! \brief Optimization pass level */ +/*! + * \brief A data structure to map the names of specific optimizations to + * numeric optimization levels + * + */ struct OptPassLevel { static const std::unordered_map _data; static std::unordered_map _declare_opt_level() { @@ -113,6 +116,12 @@ struct OptPassLevel { ret["EliminateCommonSubexpr"] = 3; return ret; } + /*! + * \brief Get level for an optimization pass + * + * \param key pass name + * \return int level + */ int operator[](const std::string& key) const { auto it = _data.find(key); if (it == _data.end()) { @@ -125,21 +134,27 @@ struct OptPassLevel { const std::unordered_map OptPassLevel::_data = OptPassLevel::_declare_opt_level(); -/*! \brief Output of function building */ +/*! + * \brief Output of building module + * + */ struct BuildOutput { std::string graph_json; runtime::Module mod; std::unordered_map params; }; -/*! \brief Relay Building configuration */ +/*! + * \brief Relay building config + * + */ struct RelayBuildConfig { int opt_level{2}; - std::string fall_back_device{"llvm"}; + std::string fallback_device{"llvm"}; std::unordered_set add_pass; std::unordered_set disabled_pass; OptPassLevel OPT_PASS_LEVEL; - inline bool pass_enabled(std::string pass_name) const { + inline bool pass_enabled(const std::string& pass_name) const { if (disabled_pass.count(pass_name)) { return false; } @@ -150,7 +165,10 @@ struct RelayBuildConfig { } }; -/*! \brief GraphCodegen wrapper */ +/*! 
+ * \brief GraphCodegen module wrapper + * + */ struct GraphCodegen { public: GraphCodegen() { @@ -166,27 +184,27 @@ struct GraphCodegen { tgts.push_back(kv.first); tgts.push_back(kv.second); } - _CallFunc("init", m, tgts); + CallFunc("init", m, tgts); } - void Codegen(Function func) { - _CallFunc("codegen", func); + void Codegen(const Function& func) { + CallFunc("codegen", func); } std::string GetJSON() { - return _CallFunc("get_graph_json", nullptr); + return CallFunc("get_graph_json", nullptr); } Map > GetLoweredFunc() { - return _CallFunc > >("get_lowered_funcs", nullptr); + return CallFunc > >("get_lowered_funcs", nullptr); } std::unordered_map GetParams() { std::unordered_map ret; - auto names = _CallFunc >("list_params_name", nullptr); + auto names = CallFunc >("list_params_name", nullptr); for (auto expr : names) { auto key = expr.as()->value; - ret[key] = _CallFunc("get_param_by_name", key); + ret[key] = CallFunc("get_param_by_name", key); } return ret; } @@ -194,12 +212,12 @@ struct GraphCodegen { protected: tvm::runtime::Module mod; template - R _CallFunc(const std::string &name, Args... args) { + R CallFunc(const std::string &name, Args... args) { auto pf = mod.GetFunction(name, false); return pf(std::forward(args)...); } template - void _CallFunc(const std::string &name, Args... args) { + void CallFunc(const std::string &name, Args... args) { auto pf = mod.GetFunction(name, false); pf(std::forward(args)...); return; @@ -207,18 +225,21 @@ struct GraphCodegen { }; template -R _CallPacked(const std::string &name, Args... args) { +R CallPackedFunc(const std::string &name, Args... args) { auto pf = GetPackedFunc(name); return (*pf)(std::forward(args)...); } template -Function _CallPacked(const std::string &name, Args... args) { +Function CallPackedFunc(const std::string &name, Args... args) { auto pf = GetPackedFunc(name); return (*pf)(std::forward(args)...); } - +/*! + * \brief Relay build module + * + */ class RelayBuildModule : public runtime::ModuleNode { public: /*! 
@@ -231,11 +252,11 @@ class RelayBuildModule : public runtime::ModuleNode { const std::shared_ptr& sptr_to_self) { if (name == "get_graph_json") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->_GetGraphJSON(); + *rv = this->GetGraphJSON(); }); } else if (name == "get_module") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->_GetModule(); + *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -247,40 +268,48 @@ class RelayBuildModule : public runtime::ModuleNode { auto v = tmp[i + 1].as()->value; targets[k] = v; } - this->_Build(args[0], targets, args[2]); + this->Build(args[0], targets, args[2]); }); } else if (name == "list_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->_ListParamNames(); + *rv = this->ListParamNames(); }); - } else if (name == "get_param_by_name") { + } else if (name == "get_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - *rv = this->_GetParam(args[0]); + *rv = this->GetParams(); }); } else if (name == "set_opt_level") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 1); int level = args[0]; - this->_SetOptLevel(level); + this->SetOptLevel(level); }); } else if (name == "set_fallback_device") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string dev = args[0]; - this->_SetFallBackDev(dev); + this->SetFallBackDev(dev); }); } else if (name == "add_pass") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string pass_name = args[0]; - this->_AddPass(pass_name); + this->AddPass(pass_name); }); } else if (name == "disable_pass") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string pass_name = args[0]; - this->_DisablePass(pass_name); + this->DisablePass(pass_name); + }); + } else if (name == "set_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Map params = args[0]; + for (const auto& kv : params) { + this->SetParam(kv.first, kv.second->data); + } }); } else { - return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {}); + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) { + LOG(FATAL) << "Unknown packed function: " << name; + }); } } @@ -289,40 +318,48 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return const std::string graph_json */ - const std::string _GetGraphJSON() { + const std::string GetGraphJSON() { return ret_.graph_json; } /*! - * \brief Add extra pass during build + * \brief Add extra pass into build cfg * - * \param pass_name + * \param pass_name name of pass */ - void _AddPass(const std::string& pass_name) { + void AddPass(const std::string& pass_name) { cfg_.add_pass.insert(pass_name); } - - void _DisablePass(const std::string& pass_name) { + /*! + * \brief Disable a specific pass in cfg + * + * \param pass_name name of pass + */ + void DisablePass(const std::string& pass_name) { cfg_.disabled_pass.insert(pass_name); } - - void _SetFallBackDev(const std::string& dev) { - cfg_.fall_back_device = dev; + /*! + * \brief Set the Fallback device + * + * \param device name + */ + void SetFallBackDev(const std::string& dev) { + cfg_.fallback_device = dev; } /*! 
* \brief Get the Module object * * \return runtime::Module */ - runtime::Module _GetModule() { + runtime::Module GetModule() { return ret_.mod; } /*! * \brief List all paramter names * - * \return Array + * \return Array names of params */ - Array _ListParamNames() { + Array ListParamNames() { Array ret; for (const auto& kv : params_) { ret.push_back(ir::StringImm::make(kv.first)); @@ -331,14 +368,16 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Get the Param of name + * \brief Get params dictionary * - * \param name - * \return runtime::NDArray + * \return Map params dictionary */ - runtime::NDArray _GetParam(const std::string& name) { - CHECK_GT(params_.count(name), 0) << "Can not find param with name: " << name; - return params_[name]; + Map GetParams() { + Map ret; + for (const auto& kv : ret_.params) { + ret.Set(kv.first, ConstantNode::make(kv.second)); + } + return ret; } /*! @@ -347,12 +386,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param name name of parameter * \param data_in input DLTensor */ - void _SetParams(const std::string& name, DLTensor* data_in) { - if (!params_.count(name)) { - std::vector shape(data_in->shape, data_in->shape + data_in->ndim); - params_[name] = tvm::runtime::NDArray::Empty(shape, data_in->dtype, {kDLCPU, 0}); - } - params_[name].CopyFrom(data_in); + void SetParam(const std::string& name, runtime::NDArray data_in) { + params_[name] = data_in; } /*! @@ -360,7 +395,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \param level */ - void _SetOptLevel(char level) { + void SetOptLevel(char level) { cfg_.opt_level = level; } @@ -380,23 +415,22 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void _Build(Function func, + void Build(Function func, const std::unordered_map& targets, const std::string& target_host) { targets_ = targets; target_host_ = target_host; - _BuildRelay(func, cfg_, params_); + BuildRelay(func, cfg_, params_); } protected: /*! - * \brief bind params to function - * + * \brief Bind params to function by using name * \param func Relay function * \param params params dict * \return relay::Function */ - relay::Function _bind_params_by_name(relay::Function func, + relay::Function BindParamsByName(relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; std::unordered_set repeat_var; @@ -418,10 +452,10 @@ class RelayBuildModule : public runtime::ModuleNode { if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - auto e = _CallPacked("relay._make.Constant", kv.second); + auto e = CallPackedFunc("relay._make.Constant", kv.second); bind_dict[arg] = e; } - return _CallPacked("relay._expr.Bind", func, tvm::Map(bind_dict)); + return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); } /*! 
@@ -433,16 +467,16 @@ class RelayBuildModule : public runtime::ModuleNode { * \param params params dict * \return relay::Function */ - relay::Function _Optimize(relay::Function func, - const std::unordered_map& targets, - const RelayBuildConfig& cfg, - const std::unordered_map& params) { + relay::Function Optimize(relay::Function func, + const std::unordered_map& targets, + const RelayBuildConfig& cfg, + const std::unordered_map& params) { if (params.size()) { - func = _bind_params_by_name(func, params); + func = BindParamsByName(func, params); } if (cfg.pass_enabled("SimplifyInference")) { - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.simplify_inference", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.simplify_inference", func); } if (cfg.pass_enabled("EliminateCommonSubexpr")) { auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -453,60 +487,70 @@ class RelayBuildModule : public runtime::ModuleNode { if (op_node->name == "cast") { auto attrs = call_node->attrs.as(); if (attrs->dtype == HalideIR::Int(32)) { - return true; + *rv = true; } } } - return false; + *rv = false; }); - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.eliminate_common_subexpr", func, fskip); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); } if (cfg.pass_enabled("CombineParallelConv2D")) { - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.CombineParallelConv2D", func); + const int min_num_branches = 3; + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); } if (cfg.pass_enabled("FoldConstant")) { - func = _CallPacked("relay._ir_pass.FoldConstant", func); + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); } if (cfg.pass_enabled("FoldScaleAxis")) { - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.backward_fold_scale_axis", func); - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.forward_fold_scale_axis", func); - func = _CallPacked("relay._ir_pass.FoldConstant", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); } if (cfg.pass_enabled("CanonicalizeOps")) { - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.canonicalize_ops", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); } if (cfg.pass_enabled("AlterOpLayout")) { if (targets.size() == 1) { - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.AlterOpLayout", func); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); } else { LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" << " execution yet."; } } if 
(cfg.pass_enabled("FoldConstant")) { - func = _CallPacked("relay._ir_pass.FoldConstant", func); + func = CallPackedFunc("relay._ir_pass.FoldConstant", func); } return func; } - - Map _UpdateHeterogeneousInputs( + /*! + * \brief Update the target and fallback device required for heterogeneous + * compilation. CPU is used as the fallback device if it wasn't provided. + * Meanwhile, a CPU device type and "llvm" pair will be added to the target + * dictionary in this case. + * + * \param targets dictionary + * \param cfg + * \return Map + */ + Map UpdateHeterogeneousInputs( const std::unordered_map& targets, const RelayBuildConfig& cfg) { Map device_target; std::unordered_map tmp_map; - auto fallback_idx = ContextMap::Str2Mask(cfg.fall_back_device); + auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); for (const auto& kv : targets) { tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; } if (tmp_map.count(fallback_idx) == 0) { - tmp_map[fallback_idx] = cfg.fall_back_device; + tmp_map[fallback_idx] = cfg.fallback_device; } for (const auto& kv : tmp_map) { device_target.Set( @@ -515,25 +559,34 @@ class RelayBuildModule : public runtime::ModuleNode { } return device_target; } - - Function _RunDeviceAnnotationPass( - Function func, - const RelayBuildConfig& cfg, - Map* targets_map_ptr) { - auto fallback_idx = ContextMap::Str2Mask(cfg.fall_back_device); - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); - auto device_map = _CallPacked >("relay._ir_pass.CollectDeviceInfo", + /*! + * \brief Execute the device annotation passes to update the input program and + * target information. + * + * \param func + * \param cfg + * \param targets_map_ptr + * \return Function + */ + Function RunDeviceAnnotationPass( + Function func, + const RelayBuildConfig& cfg, + Map* targets_map_ptr) { + auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); + auto device_map = CallPackedFunc >("relay._ir_pass.CollectDeviceInfo", func, nullptr); if (device_map.size() == 0) { - auto annotation_map = _CallPacked >("_ir_pass.CollectDeviceAnnotationOps", - func, - nullptr); + auto annotation_map = + CallPackedFunc >("relay._ir_pass.CollectDeviceAnnotationOps", + func, + nullptr); if (annotation_map.size() == 0) { targets_map_ptr->Set( ir::IntImm::make(HalideIR::Int(64), 0), - ir::StringImm::make(cfg.fall_back_device)); + ir::StringImm::make(cfg.fallback_device)); } else { int64_t dev_type = -1; for (auto kv : annotation_map) { @@ -556,16 +609,16 @@ class RelayBuildModule : public runtime::ModuleNode { return func; } /*! 
- * \brief Build module given lowered functions + * \brief Build module given lowered functions for each target * - * \param lowered_funcs - * \param targets - * \param cfg + * \param lowered_funcs target_str -> Array map + * \param targets Targets map + * \param cfg Building configuration */ - void _BuildModule(Map > lowered_funcs, - Map targets, - const BuildConfig& cfg) { - auto target_host = Target::create(cfg_.fall_back_device); + void BuildModule(const Map >& lowered_funcs, + const Map& targets, + const BuildConfig& cfg) { + auto target_host = Target::create(cfg_.fallback_device); for (const auto& kv : lowered_funcs) { std::unordered_set fname_set; for (auto f : kv.second) { @@ -584,18 +637,17 @@ class RelayBuildModule : public runtime::ModuleNode { std::vector device_module; for (const auto& kv : lowered_funcs) { auto target = target_map[kv.first]; - auto mdev = build(kv.second, - target, - target_host, - cfg, - &fhost_all, - &device_module); + auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg); + for (auto f : host_dev_funcs[0]) { + fhost_all.push_back(f); + } + if (host_dev_funcs[1].size()) { + auto mdev = codegen::Build(host_dev_funcs[1], target->str()); + device_module.push_back(mdev); + } } - auto mhost = build(fhost_all, - target_host, - target_host, - cfg); + auto mhost = codegen::Build(fhost_all, target_host->str()); for (auto mdev : device_module) { mhost.Import(mdev); @@ -607,25 +659,32 @@ class RelayBuildModule : public runtime::ModuleNode { * \brief Build relay function to runtime module * * \param func Relay Function - * \param target target device - * \param target_host host device * \param cfg Relay build config - * \param params params + * \param params parameters * \return BuildOutput */ - void _BuildRelay(relay::Function func, - const RelayBuildConfig& cfg, - const std::unordered_map ¶ms) { + void BuildRelay(Function func, + const RelayBuildConfig& cfg, + const std::unordered_map ¶ms) { // convert tvm_cfg_ = build_config(); - auto device_target = _UpdateHeterogeneousInputs(targets_, cfg); - func = _Optimize(func, targets_, cfg, params); + Map device_target; if (targets_.size() > 1) { - func = _RunDeviceAnnotationPass(func, cfg, &device_target); + device_target = UpdateHeterogeneousInputs(targets_, cfg); + } else { + for (auto &kv : targets_) { + device_target.Set( + ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)), + ir::StringImm::make(kv.second)); + } + } + func = Optimize(func, targets_, cfg, params); + if (device_target.size() > 1) { + func = RunDeviceAnnotationPass(func, cfg, &device_target); } - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); - func = _CallPacked("relay._ir_pass.FuseOps", func, cfg.opt_level); - func = _CallPacked("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level); + func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); graph_codegen_ = std::unique_ptr(new GraphCodegen()); graph_codegen_->Init(nullptr, device_target); @@ -634,9 +693,9 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - _BuildModule(graph_codegen_->GetLoweredFunc(), - device_target, - tvm_cfg_); + BuildModule(graph_codegen_->GetLoweredFunc(), + device_target, + tvm_cfg_); } protected: diff --git a/src/relay/backend/graph_runtime_codegen.cc 
b/src/relay/backend/graph_runtime_codegen.cc index 7f16891da8a7a..415e0ec9c2a55 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -416,7 +416,12 @@ class GraphRuntimeCodegen } else { // heterogeneous execution. const auto call_dev_key = std::to_string(call_dev_type); - const auto call_dev_name = runtime::DeviceName(call_dev_type); + std::string call_dev_name; + if (call_dev_type == 0) { + call_dev_name = "llvm"; + } else { + call_dev_name = runtime::DeviceName(call_dev_type); + } if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { LOG(FATAL) << "No target is provided for device " << call_dev_name; diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc new file mode 100644 index 0000000000000..3c421bf1dcf09 --- /dev/null +++ b/tests/cpp/relay_build_module_test.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +TEST(Relay, BuildModule) { + using namespace tvm; + auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32)); + auto a = relay::VarNode::make("a", tensor_type); + auto b = relay::VarNode::make("b", tensor_type); + auto add_op = relay::Op::Get("add"); + auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); + auto c = relay::VarNode::make("c", tensor_type); + auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); + auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {}); + auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + // auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto pA = (float*)A.ToDLPack()->dl_tensor.data; + auto pB = (float*)B.ToDLPack()->dl_tensor.data; + auto pC = (float*)C.ToDLPack()->dl_tensor.data; + // auto pY = (float*)Y.ToDLPack()->dl_tensor.data; + for (int i = 0; i < 6; ++i) { + pA[i] = i; + pB[i] = i + 1; + pC[i] = i + 2; + // pY[i] = 0; + } + // build + auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); + tvm::runtime::Module build_mod = (*pfb)(); + auto build_f = build_mod.GetFunction("build", false); + auto json_f = build_mod.GetFunction("get_graph_json", false); + auto mod_f = build_mod.GetFunction("get_module", false); + build_f(func, "llvm", "llvm"); + std::string json = json_f(); + tvm::runtime::Module mod = mod_f(); + // run + auto ctx = A->ctx; + auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); + tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, 
(int)ctx.device_id); + auto set_input_f = run_mod.GetFunction("set_input", false); + auto run_f = run_mod.GetFunction("run", false); + auto get_output_f = run_mod.GetFunction("get_output", false); + set_input_f("a", A); + set_input_f("b", B); + set_input_f("c", C); + run_f(); + tvm::runtime::NDArray Y = get_output_f(0); + auto pY = (float*)Y.ToDLPack()->dl_tensor.data; + for (int i = 0; i < 6; ++i) { + CHECK_GT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); + } +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 4728c4f936129..56bfa07431d12 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -23,55 +23,85 @@ _init_api("tvm.relay.build_module") class BuildModule(object): - def __init__(self): - self.mod = relay.build_module._BuildModule() - self._get_graph_json = self.mod["get_graph_json"] - self._get_module = self.mod["get_module"] - self._build = self.mod["build"] + def __init__(self): + self.mod = relay.build_module._BuildModule() + self._get_graph_json = self.mod["get_graph_json"] + self._get_module = self.mod["get_module"] + self._build = self.mod["build"] + self._set_opt_level = self.mod["set_opt_level"] + self._set_params_func = self.mod["set_params"] + self._get_params_func = self.mod["get_params"] - def build(self, func, target, target_host): - tgts = [] - for kv in target.items(): - tgts.append(kv[0]) - tgts.append(kv[1]) - self._build(func, tgts, target_host) + + def build(self, func, target, target_host, params): + tgts = [] + for kv in target.items(): + tgts.append(kv[0]) + tgts.append(kv[1]) + self._set_params(params) + self._build(func, tgts, target_host) - def get_json(self): - return self._get_graph_json() + def get_json(self): + return self._get_graph_json() + + def get_module(self): + return self._get_module() + + def set_opt_level(self, level): + self._set_opt_level(level) + + def _set_params(self, params): + inputs = {} + for name, param in params.items(): + inputs[name] = relay.Constant(param) + self._set_params_func(inputs) + + def get_params(self): + params = self._get_params_func() + ret = {} + for key, value in params.items(): + ret[key] = value.data + return ret - def get_module(self): - return self._get_module() def test_build(): - m_bld = BuildModule() - # func - a = relay.var("a", dtype="float32", shape=(16, 8)) - b = relay.var("b", dtype="float32", shape=(8, 8)) - c = relay.var("c", dtype="float32", shape=(16, 8)) - x = relay.nn.dense(a, b) - y = relay.nn.relu(x) - z = y + c - func = relay.Function([a, b, c], z) - # build - targets = { - "cpu": "llvm -mcpu=sse3" - } - m_bld.build(func, targets, "llvm -mcpu=sse3") - g_json = m_bld.get_json() - mmod = m_bld.get_module() - + m_bld = BuildModule() + tgt_name = "llvm" + tgt = "llvm" + ctx = tvm.cpu() + # func + a = relay.var("a", dtype="float32", shape=(16, 8)) + b = relay.var("b", dtype="float32", shape=(8, 8)) + c = relay.var("c", dtype="float32", shape=(16, 8)) + x = relay.nn.dense(a, b) + y = relay.nn.relu(x) + z = y + c + func = relay.Function([a, b, c], z) + A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx) + B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx) + C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx) + params = { + "b" : B, + "c" : C + } + # 
build + targets = { + tgt: tgt + } + m_bld.set_opt_level(3) + m_bld.build(func, targets, "llvm -mcpu=sse3", params=params) + g_json = m_bld.get_json() + mmod = m_bld.get_module() + params = m_bld.get_params() - # test - A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32")) - B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32")) - C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32")) - - rt = tvm.contrib.graph_runtime.create(g_json, mmod, tvm.cpu()) - rt.set_input("a", A) - rt.set_input("b", B) - rt.set_input("c", C) - rt.run() - out = rt.get_output(0) + # test + rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx) + rt.set_input("a", A) + rt.load_params(relay.save_param_dict(params)) + rt.run() + out = rt.get_output(0) - np.testing.assert_allclose(out.asnumpy(), - np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5) \ No newline at end of file + np.testing.assert_allclose(out.asnumpy(), + np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5) + +# test_build()