diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 8a2bbcbd01213..4c6388932d774 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -111,6 +111,13 @@ class TargetKind : public ObjectRef { TVM_DLL static Optional Get(const String& target_kind_name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); + /*! + * \brief Look up for TargetKind registered hooks + * \param hook_name Name of the registered hook + * \return The associated PackedFunc for the hook + */ + TVM_DLL const PackedFunc* GetRegisteredHook(const String& hook_name) const; + private: /*! \brief Mutable access to the container class */ TargetKindNode* operator->() { return static_cast(data_.get()); } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 7840960ec268e..4fe5d6c61ca80 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -59,6 +59,22 @@ using namespace tvm::relay::transform; TVM_REGISTER_OBJECT_TYPE(TECompilerNode); +/*! + * \brief Get target hook from function after checking TargetKind registry + * + * \param func - Function to get hook from + * \param hook_name - Name of hook to acquire + * \return Pointer to the packed function in the registry or nullptr if not found + */ +const PackedFunc* GetTargetHookFromFunction(const Function& func, const String& hook_name) { + auto code_gen_name = func->GetAttr(attr::kCompiler).value(); + auto target_kind = tvm::TargetKind::Get(code_gen_name); + if (target_kind) { + return target_kind.value().GetRegisteredHook(hook_name); + } + return nullptr; +} + class TECompilerImpl : public TECompilerNode { public: // Lower the function. @@ -112,10 +128,12 @@ class TECompilerImpl : public TECompilerNode { auto src_func = it.first->source_func; ICHECK(src_func.defined()); if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - std::string code_gen_name = code_gen.value(); + // Skip this function if it was actually lowered to TIR instead of a Runtime Module + if (GetTargetHookFromFunction(src_func, "relay_to_tir") != nullptr) { + continue; + } cached_ext_funcs.push_back(it.first); - + auto code_gen_name = src_func->GetAttr(attr::kCompiler).value(); auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); @@ -185,17 +203,28 @@ class TECompilerImpl : public TECompilerNode { } cur_ccache_key_ = key; - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = GetUniqueName(name_node.value(), &name_map_); - auto target = Target("ext_dev"); auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); + + auto ir_module = IRModule(); ir_module->Add(global_var, key->source_func); + + // Lower to TIR if we have a registered lowering hook + auto custom_lowering_to_tir = GetTargetHookFromFunction(key->source_func, "relay_to_tir"); + if (custom_lowering_to_tir != nullptr) { + IRModule lowered_module = (*custom_lowering_to_tir)(ir_module, key->source_func); + value->cached_func = + CachedFunc(key->target, global_var, {}, {}, te::Schedule(), {}, lowered_module); + return value; + } + + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + auto target = Target("ext_dev"); value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); return value; } @@ -340,7 +369,9 @@ class LowerTensorExpr : public ExprMutator { Target target; - if (func->GetAttr(attr::kCompiler).defined()) { + // If a custom lowering hook is registered, it will be resolved during the call to Lower() + if (func->GetAttr(attr::kCompiler).defined() && + GetTargetHookFromFunction(func, "relay_to_tir") == nullptr) { target = Target("ext_dev"); CCacheKey key = CCacheKey(func, target); CachedFunc ext_func = compiler_->Lower(key, module_name_); @@ -414,7 +445,7 @@ class LowerTensorExpr : public ExprMutator { ProcessFn process_fn; String module_name_; TECompiler compiler_; -}; +}; // namespace tec /*! * \brief Obtain the Target from the device type. diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 97317b5c48003..d95ed4beb00ee 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -70,6 +70,15 @@ Optional TargetKind::Get(const String& target_kind_name) { return reg->kind_; } +const PackedFunc* TargetKind::GetRegisteredHook(const String& hook_name) const { + auto map = tvm::TargetKind::GetAttrMap(hook_name); + if (map.count(*this)) { + std::string hook_function = map[*this]; + return tvm::runtime::Registry::Get(hook_function); + } + return nullptr; +} + /********** Utility functions **********/ /*! @@ -353,6 +362,9 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); +TVM_REGISTER_TARGET_KIND("test", kDLCPU) + .set_attr("relay_to_tir", "target.test.tir_lowering"); + /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 8dba462132ac6..b1ae3ef58a610 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -27,8 +27,13 @@ using namespace tvm; +TVM_REGISTER_GLOBAL("target.test_kind.test_registered_function") + .set_body_typed([](IRModule mod, Target target) { return mod; }); + TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") + .set_attr("known_hook", "target.test_kind.test_registered_function") + .set_attr("unknown_hook", "target.test_kind.test_not_registered_function") .add_attr_option("my_bool") .add_attr_option>("your_names") .add_attr_option>("her_maps"); @@ -158,6 +163,20 @@ TEST(TargetKindRegistryListTargetKinds, Basic) { ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } +TEST(TargetHookCheck, HookRegisteredNonNull) { + auto target_kind = tvm::TargetKind::Get("TestTargetKind").value(); + const PackedFunc* target_hook = + tvm::runtime::Registry::Get("target.test_kind.test_registered_function"); + ICHECK_NE(target_hook, (const PackedFunc*)nullptr); + ICHECK_EQ(target_kind.GetRegisteredHook("known_hook"), target_hook); +} + +TEST(TargetHookCheck, HookRegisteredNull) { + auto target_kind = tvm::TargetKind::Get("TestTargetKind").value(); + const PackedFunc* unknown_func = nullptr; + ICHECK_EQ(target_kind.GetRegisteredHook("unknown_hook"), unknown_func); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_target_hooks.py b/tests/python/relay/test_target_hooks.py new file mode 100644 index 0000000000000..61b3d4eab3212 --- /dev/null +++ b/tests/python/relay/test_target_hooks.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for target hooks.""" +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform + +from tvm import relay +from utils.external_codegen import ( + set_external_func_attr, + check_aot_executor_result, + check_graph_executor_result, +) + + +def translate_relay_add_to_tir_subtract(ir_module, relay_func): + """A transform to test Relay -> TIR with""" + ib = tvm.tir.ir_builder.create() + A = tvm.tir.decl_buffer( + dtype=relay_func.params[0].checked_type.dtype, + name=relay_func.params[0].name_hint, + shape=relay_func.params[0].checked_type.shape, + ) + B = tvm.tir.decl_buffer( + dtype=relay_func.params[1].checked_type.dtype, + name=relay_func.params[1].name_hint, + shape=relay_func.params[1].checked_type.shape, + ) + C = tvm.tir.decl_buffer(dtype=relay_func.ret_type.dtype, shape=relay_func.ret_type.shape) + + Ap = ib.buffer_ptr(A) + Bp = ib.buffer_ptr(B) + Cp = ib.buffer_ptr(C) + + with ib.for_range(0, 8, name="i") as i: + with ib.for_range(0, 8, name="j") as j: + row = i * 8 + Cp[row + j] = Ap[row + j] - Bp[row + j] + + prim_func = tvm.tir.PrimFunc([A, B, C], ib.get()) + + ir_module = tvm.lower(prim_func, name=relay_func.attrs["global_symbol"]) + return ir_module + + +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_aot_executor_result]) +def test_tir_external_generation(check_result): + tvm.register_func("target.test.tir_lowering", translate_relay_add_to_tir_subtract, True) + + shape = (8, 8) + x_data = np.random.randint(255, size=shape).astype("float32") + y_data = np.random.randint(255, size=shape).astype("float32") + inputs = {"x": x_data, "y": y_data} + + x0 = relay.var("x0", shape=shape, dtype="float32") + y0 = relay.var("y0", shape=shape, dtype="float32") + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "test", "replace_add_with_subtract") + + x = relay.var("x", shape=(8, 8), dtype="float32") + y = relay.var("y", shape=(8, 8), dtype="float32") + call = relay.Call(f, [x, y]) + func = tvm.IRModule.from_expr(call) + + check_result(func, inputs, (8, 8), x_data - y_data) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))