diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index 994a673298f1e..d387263a06a83 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -39,7 +39,7 @@ def test_ext_dev(): def check_llvm(): if not tvm.testing.device_enabled("llvm"): return - f = tvm.build(s, [A, B], "ext_dev", "llvm") + f = tvm.build(s, [A, B], "ext_dev", "ext_dev") dev = tvm.ext_dev(0) # launch the kernel. a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index fffcab49667c1..14ea5119e0e55 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -54,7 +54,8 @@ using tvm::transform::Pass; * \param target The device Target. * \return The composite Pass for the fused module. // */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, + Optional target = NullOpt); /*! 
* \brief Configures and returns the composite Pass for the device Target after device/host from diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index d47b3d4a7de66..8f62323470595 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name)) + primfunc = primfunc.with_attr( + "target", tvm.target.Target(compiler_name, host=compiler_name) + ) return primfunc def __call__(self, *args, **kwargs): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc9..23b239b1fc90d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -279,17 +279,6 @@ Array CreatePassList(bool disable_loop_partition) { return pass_list; } -IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = tvm::transform::Sequential(pass_list); - mod = optimize(std::move(mod)); - return mod; -} - -IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { - mod = seq(std::move(mod)); - return mod; -} - // Convert te schedule to IRModule IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, @@ -342,7 +331,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module") IRModule LowerModule(IRModule mod, bool simple_mode) { Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + tvm::transform::Sequential optimize(pass_list, "tvm.lower"); + return optimize(std::move(mod)); } TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool 
simple_mode) { @@ -359,10 +349,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_ f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - - // Get the pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_primfunc") @@ -384,9 +371,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array& args, const std const std::unordered_map& binds, GlobalVarSupply global_var_supply, bool simple_mode) { IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply); - // Get the legacy TE pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(mod, pass_list); + return LowerModule(mod, simple_mode); } TVM_REGISTER_GLOBAL("driver.lower_schedule") @@ -403,35 +388,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") simple_mode); }); -/** - * This function takes the input module that contains both the device and host opts. - * Then, it applies transformation on the original module before splitting into separate modules for - * device and host. Then it also applies transformations on the new splitted modules. - */ -std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { - Target target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - - ICHECK(mod_mixed.defined()) << "This module must be defined"; +IRModule MergeModules(const Map& inputs) { + if (inputs.size() == 1) { + auto [target, mod] = *inputs.begin(); + return tir::transform::BindTarget(target)(mod); + } - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); + // Take the attrs from the first module so the eventual modules have them. 
+ IRModule first_module = (*inputs.begin()).second; + IRModule merged = IRModule(Map(), {}, {}, {}, first_module->attrs); - IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); + for (auto [target, mod] : inputs) { + mod = tir::transform::BindTarget(target)(mod); + merged->Update(mod); + } - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); + return merged; +} - auto keys = target->GetKeys(); +Map SplitModule(const IRModule& module) { + Map split; - CheckAndUpdateHostConsistency(&target, &target_host); + for (auto [gvar, base_func] : module->functions) { + auto target_str = base_func->GetAttr(tvm::attr::kTarget).value()->str(); + if (auto it = split.find(target_str); it != split.end()) { + (*it).second->Add(gvar, base_func); + } else { + split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs)); + } + } - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && device_mod->functions.size() == 0) { - DLOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; + Map out; + for (auto [str, mod] : split) { + out.Set(Target(str), mod); } - return {host_mod, device_mod}; + return out; } /*! @@ -478,52 +470,86 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, // Update target host for all targets CheckAndUpdateHostConsistency(&inputs, &target_host); - // Take the attrs from the first module so the eventual modules have them. 
- // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, {}, {}, first_module->attrs); - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - const Target& target = it.first; - const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - // We don't want library modules going back into host codegen - // unless they're supposed to. Here if we overrode the target host - // to allow lowering previously we check that it's meant to be placed - // back into the host Module. - bool overrides_host_target = - target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); - bool non_host_target_kind = target->kind != target_host->kind; - if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(host_mod, it.first)); - } else { - mhost_all->Update(host_mod); + auto has_gpu_function = [](const IRModule& mod) -> bool { + for (const auto& [gvar, func] : mod->functions) { + if (auto target = func->GetAttr(tvm::attr::kTarget)) { + if (target.value()->HasKey("gpu")) { + return true; + } } + } + return false; + }; + + IRModule merged = MergeModules(inputs); + + bool contains_gpu_function_pre = has_gpu_function(merged); + merged = MixedModulePassManager(merged)(merged); + bool contains_gpu_function_post = has_gpu_function(merged); + if (contains_gpu_function_pre && !contains_gpu_function_post) { + DLOG(WARNING) << "Specified GPU targets, " + << "but cannot find device code. 
Did you forget to bind?"; + } + + Map split = SplitModule(merged); - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); + Map built; + for (const auto& [target, mod] : split) { + built.Set(target, codegen::Build(mod, target)); + } + + auto host_target = [&]() -> Target { + // All targets that contain a kIsEntryFunc=True function + Array targets_with_entry_func; + + // All targets that can run on the CPU and contain at least one + // function without kIsEntryFunc=False. + Array cpu_targets; + for (const auto& [target, mod] : split) { + bool contains_entry_func = false; + bool may_contain_entry_func = false; + for (const auto& [gvar, func] : mod->functions) { + Optional is_entry_func = func->attrs.GetAttr(tvm::tir::attr::kIsEntryFunc); + if (is_entry_func.defined() && is_entry_func.value()->value) { + contains_entry_func = true; + } else if (!is_entry_func.defined()) { + may_contain_entry_func = true; + } + } + + if (contains_entry_func) { + targets_with_entry_func.push_back(target); + } + + if (may_contain_entry_func && target->HasKey("cpu")) { + cpu_targets.push_back(target); } } - } - runtime::Module mhost = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - mhost.Import(it); + if (targets_with_entry_func.size()) { + ICHECK_EQ(targets_with_entry_func.size(), 1) + << "Expected at most one function " + << "annotated with tvm::tir::attr::kIsEntryFunc " + << "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), " + << "but found: " << targets_with_entry_func; + return targets_with_entry_func[0]; + } else if (cpu_targets.size() == 1) { + return cpu_targets[0]; + } else { + LOG(FATAL) << "Could not determine which target is the host. 
" + << "No function was annotated with tvm::tir::attr::kIsEntryFunc (\"" + << tvm::tir::attr::kIsEntryFunc << "\"), " + << "and " << cpu_targets.size() << " targets have the 'cpu' key"; + } + }(); + + auto runtime_module = built[host_target]; + for (const auto& [target, mod] : built) { + if (!mod.same_as(runtime_module)) { + runtime_module.Import(mod); } } - - return mhost; + return runtime_module; } TVM_REGISTER_GLOBAL("driver.tir_to_runtime") @@ -564,13 +590,16 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first - mixed_pass_list.push_back(tir::transform::BindTarget(target)); + // FPComputeLegalize uses the target attrs added by BindTarget, so + // BindTarget must come first. + if (target) { + mixed_pass_list.push_back(tir::transform::BindTarget(target.value())); + } mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc @@ -625,7 +654,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); - return transform::Sequential(mixed_pass_list); + // Only applies to the device functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::LowerWarpMemory()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. 
+ mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin()); + + // Apply to both host and device functions + mixed_pass_list.push_back(tir::transform::Simplify()); + mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes()); + mixed_pass_list.push_back(tir::transform::LowerIntrin()); + mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); + + // Only applies to the host functions, identified by inspection of + // each function's tvm::attr::kTarget attribute. + mixed_pass_list.push_back(tir::transform::CombineContextCall()); + if (pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value()) { + mixed_pass_list.push_back(tir::transform::InstallDebugSpans()); + } + + return transform::Sequential(mixed_pass_list, "tvm.build"); } TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") @@ -634,6 +684,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. " + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_debug = pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value(); @@ -659,7 +713,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho host_pass_list.push_back(tir::transform::InstallDebugSpans()); } - return transform::Sequential(host_pass_list); + return transform::Sequential(host_pass_list, "tir.host_mod_passes"); } TVM_REGISTER_GLOBAL("driver.host_mod_passes") @@ -668,6 +722,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes") }); transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { + LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. 
" + << "All lowering passes are now included " + << "as part of driver.mixed_mod_passes."; + Array device_pass_list; runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == @@ -683,7 +741,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); device_pass_list.push_back(tir::transform::LowerIntrin()); - return transform::Sequential(device_pass_list); + return transform::Sequential(device_pass_list, "tir.device_mod_passes"); } TVM_REGISTER_GLOBAL("driver.device_mod_passes") diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 2b037181653c2..90c0fd41dc7c0 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator { explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) : ir_module_(ir_module), host_target_(host_target), - custom_target_(Target("example_target_hook")) {} + custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index c332314a3e6c7..b0b0440de58ee 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -610,12 +610,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, return nullptr; } -TVM_REGISTER_GLOBAL("target.build.llvm") - .set_body_typed([](IRModule mod, Target target) -> runtime::Module { - auto n = make_object(); - n->Init(mod, target); - return runtime::Module(n); - }); +namespace { +runtime::Module BuildLLVM(IRModule mod, Target target) { + auto n = make_object(); + 
n->Init(mod, target); + return runtime::Module(n); +} +} // namespace + +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM); +TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM); TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 3e013492efc23..297b5d2ad8a66 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -90,7 +90,15 @@ class CodeGenCHost : public CodeGenC { Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; - /*! \brief whether to emit forwared function declarations in the resulting C code */ + /*! \brief whether to emit forwared function declarations in the resulting C code + * + * Determines the behavior when encountering an unknown symbol as + * the callee in a `CallNode` whose operation is + * `builtin::call_extern`. If true, the unknown symbol will be + * forward-declared as a function, derived from the TIR types of + * CallNode's argument/return value. If false, the forward + * declaration is omitted. 
+ */ bool emit_fwd_func_decl_; FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 708d3ccd7621a..1d227b42d1be5 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -430,7 +430,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break .set_default_keys({"cpu"}); -TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); +TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev).set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 7f47e660625b5..730b5d6689e96 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -97,6 +97,11 @@ Type GetType(const PrimExpr& expr) { return PointerType(PrimType(address->dtype)); } } + + if (expr.as()) { + return PointerType(PrimType(DataType::Int(8))); + } + // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805bf..4c5209e4b3b35 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -29,31 +29,127 @@ #include #include +#include +#include +#include + namespace tvm { namespace tir { -class DeviceRegionAnnotater : public StmtMutator { +class DeviceRegionAnnotater : public StmtExprMutator { + using Parent = StmtExprMutator; + public: + static Stmt Apply(Target host_target, Target device_target, Stmt body) { + bool same_host_and_device = host_target->str() == device_target->str(); + if (same_host_and_device) { + return body; + } + + DeviceRegionAnnotater mutator(device_target); + body = mutator(body); + + // If no region was found that must be on the device, but the + // device and host differ (e.g. 
`T.target('c', host='llvm')`), + // then the entire region should be annotated. This preserves the + // host-side handling of DLTensor arguments, while ensuring that + // any device targets are used for the codegen. + if (mutator.current_region_ == Region::Either && !same_host_and_device) { + body = AttrStmt(device_target, tvm::attr::kTarget, 0, body); + } + + return body; + } + + private: explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. + current_region_ = Region::Device; return GetRef(op); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. + current_region_ = Region::Device; Stmt body = GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored - return StmtMutator::VisitStmt_(op); + return Parent::VisitStmt_(op); } } - private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + std::vector regions; + Array seq = op->seq.Map([&](Stmt stmt) { + current_region_ = Region::Either; + stmt = VisitStmt(stmt); + regions.push_back(current_region_); + return stmt; + }); + + bool has_host_function = std::any_of(regions.begin(), regions.end(), + [](const auto& reg) { return reg == Region::Host; }); + if (has_host_function) { + current_region_ = Region::Host; + + Array new_seq; + Array device_seq; + auto finish_device_seq = [&]() { + if (device_seq.size()) { + new_seq.push_back( + AttrStmt(device_target_, tvm::attr::kTarget, 0, SeqStmt::Flatten(device_seq))); + device_seq.clear(); + } + }; + + for (size_t i = 0; i < seq.size(); i++) { + if (regions[i] == Region::Host) { + finish_device_seq(); + new_seq.push_back(seq[i]); + } else { 
+ device_seq.push_back(seq[i]); + } + } + finish_device_seq(); + + return SeqStmt::Flatten(new_seq); + } else if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + return SeqStmt(seq); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // TODO(Lunderberg): Make a new attribute in builtin.cc to label + // host-only operations. + bool is_host_only_op = + op->op.same_as(builtin::tvm_call_packed()) || op->op.same_as(builtin::tvm_call_cpacked()) || + op->op.same_as(builtin::tvm_call_packed_lowered()) || + op->op.same_as(builtin::tvm_call_cpacked_lowered()) || + op->op.same_as(builtin::tvm_struct_get()) || op->op.same_as(builtin::tvm_struct_set()) || + op->op.same_as(builtin::tvm_throw_last_error()) || + op->op.same_as(builtin::tvm_stack_alloca()) || + op->op.same_as(builtin::tvm_stack_make_shape()) || + op->op.same_as(builtin::tvm_stack_make_array()); + if (is_host_only_op) { + current_region_ = Region::Host; + } + return Parent::VisitExpr_(op); + } + Target device_target_; + + enum class Region { + Either, + Host, + Device, + }; + Region current_region_{Region::Either}; }; namespace transform { @@ -64,9 +160,12 @@ Pass AnnotateDeviceRegions() { ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; Target target = opt_target.value(); - if (target->GetHost()) { - DeviceRegionAnnotater mutator(target.WithoutHost()); - func.CopyOnWrite()->body = mutator(func->body); + if (auto opt_host = target->GetHost()) { + auto new_body = + DeviceRegionAnnotater::Apply(opt_host.value(), target.WithoutHost(), func->body); + if (!new_body.same_as(func->body)) { + func.CopyOnWrite()->body = new_body; + } } return func; }; diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 932116485fa1d..a33376bd69eeb 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -260,7 +260,9 @@ class DeviceKernelMutator : public 
StmtExprMutator { bool same_device_type = caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); - if (same_device_type) { + bool linkable_module = (caller_target->GetTargetDeviceType() == kDLCPU) && + (callee_target->GetTargetDeviceType() == kDLExtDev); + if (same_device_type || linkable_module) { // Calls to another target using the same device (e.g. LLVM // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 212ccf6e56169..fbc5d4fda92dc 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -44,6 +44,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { + if (target == "ext_dev") { + target = "llvm"; + } + std::vector patterns; patterns.push_back(target + ".FLowerIntrinsic"); patterns.push_back(target + ".FLegalize"); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c90384fea73a9..abc7ce91efb0d 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -97,7 +97,8 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, - {tir::attr::kIsGlobalFunc, Bool(true)}}); + {tir::attr::kIsGlobalFunc, Bool(true)}, + {tir::attr::kIsEntryFunc, Bool(false)}}); GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); diff --git a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py b/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py index efa43027e9c64..7b869ddf7694b 100644 --- 
a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py +++ b/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py @@ -54,5 +54,76 @@ def expected(A: T.Buffer(1, "float32")): A[0] = 0.0 +class TestAnnotateEntireBody(BaseCompare): + """Annotation inserted to wrap entire function + + Function is assumed to belong on the device. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + A[0] = 0.0 + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + A[0] = 0.0 + + +class TestNoAnnotationForSameHostDevice(BaseCompare): + """No annotation is needed if host/device are the same""" + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm", host="llvm")}) + A[0] = 0.0 + + expected = before + + +class TestAnnotationAvoidsHostConstructs(BaseCompare): + """Device annotation does not contain host-only functions + + Calls that must be on the host side (e.g. T.call_packed) remain on + the host. + """ + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + T.call_packed("dummy_function", dtype="void") + + +class TestAnnotationNoRepetition(BaseCompare): + """Device annotation does not contain host-only functions + + When placing everything that isn't a host-specific function into + target block, sequential device statements should be in the same + block. 
+ """ + + def before(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + def expected(A: T.Buffer(2, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.call_packed("dummy_function", dtype="void") + with T.attr(T.target("cuda"), "target", 0): + A[0] = 0.0 + A[1] = 1.0 + T.call_packed("dummy_function", dtype="void") + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index 6adfbeb81d54b..a05ac74f84883 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -122,6 +122,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -159,6 +160,7 @@ def main_kernel(n: T.int32) -> T.int32: "target": T.target("llvm"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -200,6 +202,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -261,6 +264,7 @@ def main_kernel_1(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 1e595c8441b25..7c02f356c6e53 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -202,7 +202,14 @@ def _post_order(op): ), op.body, ) - alloc = tvm.tir.Allocate(buffer_var, op.dtype, op.extents, op.condition, let_stmt) + alloc = tvm.tir.Allocate( + buffer_var, + 
op.dtype, + op.extents, + op.condition, + let_stmt, + annotations={"disable_lower_builtin": True}, + ) del var_remap[buffer_var] bufs_to_delete = [ old_buf for old_buf in buf_remap if old_buf.data.same_as(buffer_var) diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py new file mode 100644 index 0000000000000..3f5c693b78a07 --- /dev/null +++ b/vta/scripts/tune_resnet.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Perform ResNet autoTVM tuning on VTA using Relay.""" + +import argparse, os, time +from mxnet.gluon.model_zoo import vision +import numpy as np +from PIL import Image + +from tvm import topi +import tvm +from tvm import te +from tvm import rpc, autotvm, relay +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_executor, utils, download +from tvm.contrib.debugger import debug_executor +import vta +from vta.testing import simulator +from vta.top import graph_pack +from tvm.autotvm.task import extract_from_program + + +def parse_arguments(): + + parser = argparse.ArgumentParser(description="Train a model for image classification.") + parser.add_argument( + "--model", + type=str, + default="resnet18_v1", + choices=["resnet18_v1"], + help="Input model name.", + ) + parser.add_argument( + "--start-name", + type=str, + default="nn.max_pool2d", + help="The name of the node where packing starts", + ) + parser.add_argument( + "--stop-name", + type=str, + default="nn.global_avg_pool2d", + help="The name of the node where packing stops", + ) + parser.add_argument( + "--debug-profile", action="store_true", help="Show layer-wise time cost profiling results" + ) + parser.add_argument( + "--device", default="vta", choices=["vta", "arm_cpu"], help="Select device target" + ) + parser.add_argument( + "--measurements", type=int, default=1, help="Number of measurements during AutoTVM search" + ) + parser.add_argument("--tuner", type=str, default="random", help="AutoTVM search strategy") + parser.add_argument( + "--log-filename", type=str, default="resnet-18.log", help="AutoTVM log file name" + ) + + return parser.parse_args() + + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.te.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min 
and max into two stages.""" + const_min = tvm.tir.const(a_min, x.dtype) + const_max = tvm.tir.const(a_max, x.dtype) + x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA") + x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.Target.current().device_name == "vta": + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = te.create_schedule([res.op]) + return s, [A, W, res] + + @autotvm.task.register("topi_nn_dense", override=True) + def _topi_nn_dense(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.dense(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.Target.current().device_name == "vta": + s = topi.generic.schedule_dense([res]) + else: + s = te.create_schedule([res.op]) + + return s, [A, W, res] + + +def compile_network(opt, env, target): + + # Populate the shape and data type dictionary + dtype_dict = {"data": "float32"} + shape_dict = {"data": (env.BATCH, 3, 224, 224)} + + # Get off the shelf gluon model, and convert to relay + gluon_model = vision.get_model(opt.model, pretrained=True) + mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict) + + # Update shape and type dictionary + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Perform 
quantization in Relay
+    # Note: We set opt_level to 3 in order to fold batch norm
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
+            relay_prog = relay.quantize.quantize(mod["main"], params=params)
+
+    # Perform graph packing and constant folding for VTA target
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
+        relay_prog = graph_pack(
+            relay_prog,
+            env.BATCH,
+            env.BLOCK_OUT,
+            env.WGT_WIDTH,
+            start_name=opt.start_name,
+            stop_name=opt.stop_name,
+        )
+
+    return relay_prog, params
+
+
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+    try_winograd=True,
+):
+
+    # create tmp log file
+    tmp_log_file = log_filename + ".tmp"
+    if os.path.exists(tmp_log_file):
+        os.remove(tmp_log_file)
+
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+
+        # create tuner
+        if tuner == "xgb":
+            tuner_obj = XGBTuner(tsk, loss_type="reg")
+        elif tuner == "xgb_knob":
+            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob")
+        elif tuner == "xgb_itervar":
+            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar")
+        elif tuner == "xgb_curve":
+            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve")
+        elif tuner == "xgb_rank":
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "xgb_rank_knob":
+            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
+        elif tuner == "xgb_rank_itervar":
+            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar")
+        elif tuner == "xgb_rank_curve":
+            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve")
+        elif tuner == "xgb_rank_binary":
+            tuner_obj = XGBTuner(tsk, loss_type="rank-binary")
+        elif tuner == "xgb_rank_binary_knob":
+            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob")
+        elif tuner == "xgb_rank_binary_itervar":
+
tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar")
+        elif tuner == "xgb_rank_binary_curve":
+            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve")
+        elif tuner == "ga":
+            tuner_obj = GATuner(tsk, pop_size=50)
+        elif tuner == "random":
+            tuner_obj = RandomTuner(tsk)
+        elif tuner == "gridsearch":
+            tuner_obj = GridSearchTuner(tsk)
+        else:
+            raise ValueError("Invalid tuner: " + tuner)
+
+        if use_transfer_learning:
+            if os.path.isfile(tmp_log_file):
+                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
+
+        # do tuning
+        n_trial_ = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(
+            n_trial_,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(n_trial_, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
+
+    # pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_filename)
+    os.remove(tmp_log_file)
+
+
+if __name__ == "__main__":
+
+    opt = parse_arguments()
+
+    # Make sure that TVM was compiled with RPC=1
+    assert tvm.runtime.enabled("rpc")
+
+    # Read in VTA environment
+    env = vta.get_env()
+
+    # Get remote from fleet node
+    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
+    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
+    if not tracker_host or not tracker_port:
+        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
+        exit()
+
+    # Get remote
+    if env.TARGET != "sim":
+
+        # Measure build start time
+        reconfig_start = time.time()
+
+        # Get remote from fleet node
+        remote = autotvm.measure.request_remote(
+            env.TARGET, tracker_host, int(tracker_port), timeout=10000
+        )
+
+        # Reconfigure the JIT runtime and FPGA.
+        # You can program the FPGA with your own custom bitstream
+        # by passing the path to the bitstream file instead of None.
+        vta.reconfig_runtime(remote)
+        vta.program_fpga(remote, bitstream=None)
+
+        # Report on reconfiguration time
+        reconfig_time = time.time() - reconfig_start
+        print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))
+
+    # In simulation mode, host the RPC server locally.
+    else:
+        remote = rpc.LocalSession()
+
+    # VTA target and execution context
+    target = env.target if opt.device == "vta" else env.target_vta_cpu
+    ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)
+
+    # Compile Relay program
+    print("Initial compile...")
+    relay_prog, params = compile_network(opt, env, target)
+
+    # Register VTA tuning tasks
+    register_vta_tuning_tasks()
+
+    # Perform task extraction on Relay program
+    print("Extracting tasks...")
+    tasks = extract_from_program(
+        func=relay_prog,
+        params=params,
+        ops=(relay.op.get("nn.conv2d"),),
+        target=tvm.target.Target(target, host=env.target_host),
+    )
+
+    # Perform Autotuning
+    print("Tuning...")
+    tuning_opt = {
+        "log_filename": opt.log_filename,
+        "tuner": opt.tuner,
+        "n_trial": 1e9,
+        "early_stopping": None,
+        "measure_option": autotvm.measure_option(
+            builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
+            runner=autotvm.RPCRunner(
+                env.TARGET,
+                tracker_host,
+                tracker_port,
+                number=4,
+                min_repeat_ms=150,
+                repeat=opt.measurements,
+                timeout=60,
+                # check_correctness=True, # TODO: re-enable when check_correctness works again.
+ ), + ), + } + tune_tasks(tasks, **tuning_opt) + + # Compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[opt.log_filename]): + + # Compile network + print("Compiling network with best tuning parameters...") + if target.device_name != "vta": + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, + target=tvm.target.Target(target, host=env.target_host), + params=params, + ) + else: + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, + target=tvm.target.Target(target, host=env.target_host), + params=params, + ) + + # Export library + temp = utils.tempdir() + lib.export_library(temp.relpath("graphlib.so")) + remote.upload(temp.relpath("graphlib.so")) + lib = remote.load_module("graphlib.so") + + # If detailed runtime info is needed build with debug runtime + if opt.debug_profile: + m = debug_executor.create(graph, lib, ctx) + else: + m = graph_executor.create(graph, lib, ctx) + + # Set the network parameters and synthetic input + image = tvm.nd.array((np.random.uniform(size=(1, 3, 224, 224))).astype("float32")) + m.set_input(**params) + m.set_input("data", image) + + # Perform inference + timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements) + tcost = timer() + prof_res = np.array(tcost.results) * 1000 # convert to millisecond + print( + "Mean inference time (std dev): %.2f ms (%.2f ms)" + % (np.mean(prof_res), np.std(prof_res)) + ) + + # Display profile information + if opt.debug_profile: + m.run() diff --git a/vta/tutorials/matrix_multiply.py b/vta/tutorials/matrix_multiply.py index 0d11678544581..1d1dd98dfaf3c 100644 --- a/vta/tutorials/matrix_multiply.py +++ b/vta/tutorials/matrix_multiply.py @@ -392,13 +392,13 @@ # Write the compiled module into an object file. 
temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) +my_gemm.export_library(temp.relpath("gemm.so")) # Send the executable over RPC -remote.upload(temp.relpath("gemm.o")) +remote.upload(temp.relpath("gemm.so")) # Load the compiled module -f = remote.load_module("gemm.o") +f = remote.load_module("gemm.so") ###################################################################### # Running the Function diff --git a/vta/tutorials/optimize/convolution_opt.py b/vta/tutorials/optimize/convolution_opt.py index 521a73ab510d7..3c757fdc0c2b5 100644 --- a/vta/tutorials/optimize/convolution_opt.py +++ b/vta/tutorials/optimize/convolution_opt.py @@ -374,9 +374,9 @@ s, [data, kernel, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_conv" ) temp = utils.tempdir() -my_conv.save(temp.relpath("conv2d.o")) -remote.upload(temp.relpath("conv2d.o")) -f = remote.load_module("conv2d.o") +my_conv.export_library(temp.relpath("conv2d.so")) +remote.upload(temp.relpath("conv2d.so")) +f = remote.load_module("conv2d.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/optimize/matrix_multiply_opt.py b/vta/tutorials/optimize/matrix_multiply_opt.py index b470475b16e7b..ea70b5260c561 100644 --- a/vta/tutorials/optimize/matrix_multiply_opt.py +++ b/vta/tutorials/optimize/matrix_multiply_opt.py @@ -314,9 +314,9 @@ s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm" ) temp = utils.tempdir() -my_gemm.save(temp.relpath("gemm.o")) -remote.upload(temp.relpath("gemm.o")) -f = remote.load_module("gemm.o") +my_gemm.export_library(temp.relpath("gemm.so")) +remote.upload(temp.relpath("gemm.so")) +f = remote.load_module("gemm.so") # Get the remote device context ctx = remote.ext_dev(0) diff --git a/vta/tutorials/vta_get_started.py b/vta/tutorials/vta_get_started.py index 3482258dece89..6edb34184fb42 100644 --- a/vta/tutorials/vta_get_started.py +++ b/vta/tutorials/vta_get_started.py @@ -327,17 +327,17 @@ # Write 
the compiled module into an object file.
 temp = utils.tempdir()
-my_vadd.save(temp.relpath("vadd.o"))
+my_vadd.export_library(temp.relpath("vadd.so"))
 
 # Send the executable over RPC
-remote.upload(temp.relpath("vadd.o"))
+remote.upload(temp.relpath("vadd.so"))
 
 ######################################################################
 # Loading the Module
 # ~~~~~~~~~~~~~~~~~~
 # We can load the compiled module from the file system to run the code.
-f = remote.load_module("vadd.o")
+f = remote.load_module("vadd.so")
 
 ######################################################################
 # Running the Function