From 0b91e539ef788489511d94f60605ffea3d32b6c9 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Mon, 11 Oct 2021 14:44:02 +0100 Subject: [PATCH] Initial Implementation of TIRToRuntime Target hook (#9190) * Initial Implementation of TIRToRuntime Target hook This is the initial implementation which wires in a test case for TIRToRuntime, in order to get this working I re-used `CodegenCHost` as it implements all of the `Op`s required from the lowered `PrimFunc`. Currently, the `IRModule` is non-unified but in future work it should definitely do so, I wanted to implement the basics here to get the infra in place. * Fix heterogeneous compute with multiple kDLCPU targets * Remove rogue te_compiler.h include --- .../modules/contrib/ExampleTargetHooks.cmake | 2 +- include/tvm/target/target_kind.h | 28 ++++++++ src/driver/driver_api.cc | 28 +++++++- .../example_target_hooks/relay_to_tir.cc | 19 ++++-- .../contrib/example_target_hooks/target.cc | 39 +++++++++++ .../example_target_hooks/tir_to_runtime.cc | 64 +++++++++++++++++++ src/target/codegen.cc | 5 ++ src/target/source/codegen_c_host.h | 2 +- tests/python/relay/test_target_hooks.py | 23 +++++++ 9 files changed, 199 insertions(+), 11 deletions(-) create mode 100644 src/relay/backend/contrib/example_target_hooks/target.cc create mode 100644 src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc diff --git a/cmake/modules/contrib/ExampleTargetHooks.cmake b/cmake/modules/contrib/ExampleTargetHooks.cmake index eb53dda133d2..e9003b02103e 100644 --- a/cmake/modules/contrib/ExampleTargetHooks.cmake +++ b/cmake/modules/contrib/ExampleTargetHooks.cmake @@ -15,5 +15,5 @@ # specific language governing permissions and limitations # under the License. -file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc) +file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/*.cc) list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC}) diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8a2bbcbd0121..9d8695a43aff 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -24,6 +24,7 @@ #ifndef TVM_TARGET_TARGET_KIND_H_ #define TVM_TARGET_TARGET_KIND_H_ +#include #include #include @@ -33,6 +34,33 @@ #include namespace tvm { + +class Target; + +/*! + * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind + * + * Called before the default lowering passes. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ +using FTVMRelayToTIR = transform::Pass; + +/*! + * \brief TIRToRuntime conversion specific to a TargetKind + * + * This function is responsible for scanning an IRModule for appropriate Target-specific functions + and generating a Runtime module representing the compiled output + * + * \param ir_module Unified IRModule + * \param target Target to filter on or retrieve arguments from + * \return Runtime Module containing compiled functions + */ +using FTVMTIRToRuntime = runtime::TypedPackedFunc; + namespace detail { template struct ValueTypeInfoMaker; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bfea3e7b67c0..2c6fbc2eb76d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -401,12 +401,21 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto opt_mixed = transform::Sequential(mixed_pass_list); mod_mixed = opt_mixed(std::move(mod_mixed)); + // We make an assumption here that the overriden host target + // can be used alongside the default host codegen based on device type + // this is so the correct code generator is used later instead of overriding the target. + // We need better support for inserting multiple kDLCPU targets as our current options + // are kDeviceKernelLaunch or not + Target overriden_host_target = target_host; + if (target->kind->device_type == target_host->kind->device_type) { + overriden_host_target = target; + } auto host_pass_list = { Filter([](const tir::PrimFunc& f) { return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), - BindTarget(target_host), + BindTarget(overriden_host_target), tir::transform::LowerTVMBuiltin(), tir::transform::LowerCustomDatatypes(), tir::transform::LowerIntrin(), @@ -487,7 +496,9 @@ runtime::Module build(const Map& inputs_arg, const Target& tar for (const auto& it : inputs) { if (it.second.defined()) { - auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx); + const Target& target = it.first; + const IRModule& ir_module = it.second; + auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -495,7 +506,17 @@ runtime::Module build(const Map& inputs_arg, const Target& tar ICHECK(mhost_all.defined()) << "The host module must be defined"; - mhost_all->Update(mhost); + // We don't want library modules going back into host codegen + // unless they're supposed to. Here if we overrode the target host + // to allow lowering previously we check that it's meant to be placed + // back into the host Module. + bool overrides_host_target = target->kind->device_type == target_host->kind->device_type; + bool non_host_target_kind = target->kind != target_host->kind; + if (overrides_host_target && non_host_target_kind) { + device_modules.push_back(codegen::Build(mhost, it.first)); + } else { + mhost_all->Update(mhost); + } if (mdevice->functions.size() != 0) { device_modules.push_back(codegen::Build(mdevice, it.first)); @@ -510,6 +531,7 @@ runtime::Module build(const Map& inputs_arg, const Target& tar mhost.Import(it); } } + return mhost; } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 6d332803041d..cae20210ec4f 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -33,7 +33,9 @@ namespace example_target_hooks { class ConvertAddToSubtract : public MixedModeMutator { public: explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) - : ir_module_(ir_module), host_target_(host_target) {} + : ir_module_(ir_module), + host_target_(host_target), + custom_target_(Target("example_target_hook")) {} IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); @@ -81,7 +83,15 @@ class ConvertAddToSubtract : public MixedModeMutator { tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), buffer_map, DictAttrs(dict_attrs)); - replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + + // Switch to TIRToRuntime hook for testing + Bool tir_to_runtime = func->GetAttr("tir_to_runtime").value_or(Bool(false)); + if (tir_to_runtime) { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_); + } else { + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + } + ir_module_->Add(new_global_var, replacement_func); } @@ -109,6 +119,7 @@ class ConvertAddToSubtract : public MixedModeMutator { public: IRModule ir_module_; Target host_target_; + Target custom_target_; }; transform::Pass RelayToTIR() { @@ -124,8 +135,4 @@ transform::Pass RelayToTIR() { } // namespace contrib } // namespace relay -TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) - .set_attr("RelayToTIR", - relay::contrib::example_target_hooks::RelayToTIR()); - } // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc new file mode 100644 index 000000000000..75b161ad4499 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -0,0 +1,39 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { + +namespace relay { +namespace contrib { +namespace example_target_hooks { +tvm::transform::Pass RelayToTIR(); +runtime::Module TIRToRuntime(IRModule mod, Target target); +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay + +TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR()) + .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime); + +} // namespace tvm diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc new file mode 100644 index 000000000000..36d801d349a7 --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../../../../target/source/codegen_c_host.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace example_target_hooks { + +using namespace tir; + +class CodeGenExampleTargetHook : public codegen::CodeGenCHost { + public: + /*! + * \brief Emit code that changes adds to multiplies for testing + */ + void VisitExpr_(const SubNode* op, std::ostream& os) final { + os << '('; + PrintExpr(op->a, os); + os << " * "; + PrintExpr(op->b, os); + os << ')'; + } +}; + +runtime::Module TIRToRuntime(IRModule mod, Target target) { + bool output_ssa = false; + bool emit_asserts = false; + CodeGenExampleTargetHook codegen; + Array function_names; + codegen.Init(output_ssa, emit_asserts, target->str()); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); + } + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); +} + +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 5a4aa39f01b4..41221ad8a33e 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -47,6 +47,11 @@ runtime::Module Build(IRModule mod, Target target) { mod = tir::transform::SkipAssert()(mod); } + auto target_attr_map = tvm::TargetKind::GetAttrMap("TIRToRuntime"); + if (target_attr_map.count(target->kind)) { + return target_attr_map[target->kind](mod, target); + } + // the build function. std::string build_f_name = "target.build." + target->kind->name; const PackedFunc* bf = runtime::Registry::Get(build_f_name); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 10a437a547c1..4ff1c6ef61ed 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -35,7 +35,7 @@ namespace tvm { namespace codegen { -class CodeGenCHost final : public CodeGenC { +class CodeGenCHost : public CodeGenC { public: CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts, std::string target_str); diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py index 4d7a7fcdc15b..5856dc1e1c69 100644 --- a/tests/python/relay/test_target_hooks.py +++ b/tests/python/relay/test_target_hooks.py @@ -49,5 +49,28 @@ def test_tir_external_generation(check_result): check_result(func, inputs, (8,), x_data - y_data) +@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) +def test_runtime_module_generation(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + + x0 = relay.var("x0", shape=shape, dtype="float32") + y0 = relay.var("y0", shape=shape, dtype="float32") + z = x0 + y0 + func = relay.Function([x0, y0], z) + func = set_external_func_attr(func, "example_target_hook", "replace_add_with_subtract") + # Test hook to trigger TIRToRuntime code generation + func = func.with_attr("tir_to_runtime", True) + + x = relay.var("x", shape=(8,), dtype="float32") + y = relay.var("y", shape=(8,), dtype="float32") + call = relay.Call(func, [x, y]) + func = IRModule.from_expr(call) + + check_result(func, inputs, (8,), x_data * y_data) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))