diff --git a/CMakeLists.txt b/CMakeLists.txt index 81c16a2dbdca..c1c068cffa68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -433,6 +433,7 @@ include(cmake/modules/contrib/EthosU.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) +include(cmake/modules/contrib/ExampleTargetHooks.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) diff --git a/cmake/modules/contrib/ExampleTargetHooks.cmake b/cmake/modules/contrib/ExampleTargetHooks.cmake new file mode 100644 index 000000000000..eb53dda133d2 --- /dev/null +++ b/cmake/modules/contrib/ExampleTargetHooks.cmake @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc) +list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC}) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index bdc46d71a77d..912879dc8a4b 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -426,6 +426,13 @@ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); */ TVM_DLL Pass SimplifyExpr(); +/*! + * \brief Run any registered RelayToTIR passes registered on the functions in a module. + * + * \return The pass. + */ +TVM_DLL Pass RelayToTIRTargetHook(); + /*! * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc new file mode 100644 index 000000000000..6d332803041d --- /dev/null +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -0,0 +1,131 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace example_target_hooks { + +class ConvertAddToSubtract : public MixedModeMutator { + public: + explicit ConvertAddToSubtract(IRModule ir_module, Target host_target) + : ir_module_(ir_module), host_target_(host_target) {} + + IRModule Mutate() { + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); + BaseFunc main = ir_module_->Lookup(main_global_var); + Function main_func = GetRef(main.as()); + + // Copy everything across and mutate the body + Function mutated_main = + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, + main_func->type_params, main_func->attrs, main_func->span); + + ir_module_->Update(main_global_var, mutated_main); + + return ir_module_; + } + + private: + tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) { + return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true()); + } + + void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) { + tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x"); + tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y"); + tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32)); + + tir::Var x_var("x", DataType::Handle()); + tir::Var y_var("y", DataType::Handle()); + tir::Var out_var("out", DataType::Handle()); + + Map dict_attrs; + dict_attrs.Set("global_symbol", new_global_var->name_hint); + dict_attrs.Set("tir.noalias", Bool(true)); + + te::Var index("index", DataType::Int(32)); + tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index)); + tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true()); + tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body); + + Map buffer_map = { + {x_var, x_buffer}, + {y_var, y_buffer}, + {out_var, out_buffer}, + }; + + tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(), + buffer_map, DictAttrs(dict_attrs)); + replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_); + ir_module_->Add(new_global_var, replacement_func); + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call = post.as()) { + auto* func = call->op.as(); + if (func == nullptr) { + return post; + } + + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + if (func_name.defined() && func_name == "replace_add_with_subtract") { + // Introduce a new global var to map the function to and copy the source type + // over for InferType + GlobalVar new_global_var(func_name.value()); + new_global_var->checked_type_ = func->checked_type(); + ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); + return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); + } + } + + return post; + } + + public: + IRModule ir_module_; + Target host_target_; +}; + +transform::Pass RelayToTIR() { + runtime::TypedPackedFunc pass_func = + [=](IRModule ir_module, transform::PassContext pass_context) { + auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c")); + return relay_to_tir.Mutate(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); +} + +} // namespace example_target_hooks +} // namespace contrib +} // namespace relay + +TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) + .set_attr("RelayToTIR", + relay::contrib::example_target_hooks::RelayToTIR()); + +} // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2e7eb6f9aa6b..e322ccaff1ce 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -131,6 +131,7 @@ class TECompilerImpl : public TECompilerNode { Array ret; std::unordered_map cached_symbol; std::vector cached_ext_funcs; + for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); @@ -383,10 +384,12 @@ class LowerTensorExprMutator : public ExprMutator { * \brief Returns the primitive function associated with \p expr, or * nullptr if none. */ - Function ResolveToPrimitive(Expr expr) { + BaseFunc ResolveToPrimitive(Expr expr) { if (const GlobalVarNode* gvn = expr.as()) { BaseFunc base_func = module_->Lookup(GetRef(gvn)); return ResolveToPrimitive(base_func); + } else if (const tir::PrimFuncNode* prim_func = expr.as()) { + return GetRef(prim_func); } else if (const VarNode* vn = expr.as()) { auto itr = primitive_functions_.find(GetRef(vn)); return itr == primitive_functions_.end() ? Function() : itr->second; @@ -516,10 +519,17 @@ class LowerTensorExprMutator : public ExprMutator { Expr VisitExpr_(const LetNode* let) override { Var var = Downcast(Mutate(let->var)); Expr value = Mutate(let->value); - Function prim_func = ResolveToPrimitive(value); + BaseFunc prim_func = ResolveToPrimitive(value); + if (prim_func.defined()) { + // Already lowered by other means, no need to mutate the Let node + if (prim_func->IsInstance()) { + return GetRef(let); + } + // Remember let var is bound to (possibly indirectly) to a primitive. - primitive_functions_.emplace(let->var, prim_func); + Function func = Downcast(prim_func); + primitive_functions_.emplace(let->var, func); } Expr body = Mutate(let->body); if (prim_func.defined()) { @@ -537,7 +547,7 @@ class LowerTensorExprMutator : public ExprMutator { Call expr = GetRef(call); // Look for (indirect) calls to primitives. - Function prim_func = ResolveToPrimitive(call->op); + BaseFunc prim_func = ResolveToPrimitive(call->op); if (!prim_func.defined()) { // Not a call to a primitive function. if (const FunctionNode* fn = call->op.as()) { @@ -546,6 +556,12 @@ class LowerTensorExprMutator : public ExprMutator { return ExprMutator::VisitExpr_(call); } + // Already lowered by other means so we don't need to mutate + // the call + if (prim_func->IsInstance()) { + return expr; + } + // Find the desired target device. Target target; if (prim_func->GetAttr(attr::kCompiler).defined()) { @@ -565,7 +581,8 @@ class LowerTensorExprMutator : public ExprMutator { } // Lower the primitive function for that target. - std::pair pair = LowerFunction(prim_func, target); + Function func = Downcast(prim_func); + std::pair pair = LowerFunction(func, target); // Similarly transform arguments. Array args; @@ -639,8 +656,6 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const Stri backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets, Map storage_info_map) { - CHECK_EQ(mod->functions.size(), 1) - << "There should only be one function in the module passed to UpdateMainWorkspaceSize"; Function func = Downcast(mod->Lookup("main")); // This is a Map> @@ -909,8 +924,10 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& PassContext ctx) { return LowerTE(module, targets, device_context_map, module_name, process_fn); }; - return tvm::transform::Sequential( - {tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()}); + + return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(), + tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), + InferType()}); } } // namespace tec } // namespace relay diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc new file mode 100644 index 000000000000..40287ded1dd8 --- /dev/null +++ b/src/relay/transforms/target_hooks.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file target_hooks.cc + * \brief Relay passes for processing Target Hooks which have been registered on functions within + * the IRModule + */ + +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +class TargetHookVisitor : public tvm::relay::MixedModeVisitor { + /*! \brief Collected pass list for all nodes */ + std::vector pass_list_; + /*! \brief Attribute map for all registered targets */ + TargetKindAttrMap target_attr_map_; + + public: + TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap("RelayToTIR")) {} + + std::vector Visit(const IRModule& ir_mod) { + for (const auto& it : ir_mod->functions) { + const BaseFunc& base_func = it.second; + VisitExpr(base_func); + } + return pass_list_; + } + + void VisitExpr_(const CallNode* call) override { + // Descend the call tree + for (auto arg : call->args) { + VisitExpr(arg); + } + + if (const FunctionNode* func = call->op.as()) { + if (!func->GetAttr(attr::kCompiler).defined()) { + return; + } + String code_gen_name = func->GetAttr(attr::kCompiler).value(); + Optional target_kind = tvm::TargetKind::Get(code_gen_name); + if (!target_kind || !target_attr_map_.count(target_kind.value())) { + return; + } + Pass custom_target_pass = target_attr_map_[target_kind.value()]; + if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { + pass_list_.push_back(custom_target_pass); + } + } + } +}; + +Pass RelayToTIRTargetHook() { + auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { + auto target_hook_visitor = TargetHookVisitor(); + std::vector pass_list = target_hook_visitor.Visit(mod); + Sequential run_hooks(pass_list); + + return run_hooks(mod); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py new file mode 100644 index 000000000000..4d7a7fcdc15b --- /dev/null +++ b/tests/python/relay/test_target_hooks.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for target hooks.""" +import sys +import numpy as np +import pytest + +from tvm import relay, IRModule + +from utils.external_codegen import ( + set_external_func_attr, + check_aot_executor_result, + check_graph_executor_result, +) + + +@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result]) +def test_tir_external_generation(check_result): + shape = (8,) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + + x0 = relay.var("x0", shape=shape, dtype="float32") + y0 = relay.var("y0", shape=shape, dtype="float32") + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "example_target_hook", "replace_add_with_subtract") + + x = relay.var("x", shape=(8,), dtype="float32") + y = relay.var("y", shape=(8,), dtype="float32") + call = relay.Call(f, [x, y]) + func = IRModule.from_expr(call) + + check_result(func, inputs, (8,), x_data - y_data) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))