Skip to content

Commit

Permalink
Implementation of relay_to_tir target hook
Browse files Browse the repository at this point in the history
This the first new hook proposed in the Additional Target Hooks RFC, longer
term the compilation should move to using `Target` proper but this unblocks our current work whilst illustrating the eventual interface via `Target` in `target_kind.cc`

I've encapsulated the hook lookup into a method on `TargetKind` (`GetRegisteredHook`), which will eventually mean that the logic can be compacted to:
```
func->target->kind.GetRegisteredHook()
```
  • Loading branch information
Mousius committed Aug 13, 2021
1 parent ae9db49 commit 997ae2d
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 9 deletions.
7 changes: 7 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ class TargetKind : public ObjectRef {
TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);

/*!
* \brief Look up for TargetKind registered hooks
* \param hook_name Name of the registered hook
* \return The associated PackedFunc for the hook
*/
TVM_DLL const PackedFunc* GetRegisteredHook(const String& hook_name) const;

private:
/*! \brief Mutable access to the container class */
TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }
Expand Down
49 changes: 40 additions & 9 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ using namespace tvm::relay::transform;

TVM_REGISTER_OBJECT_TYPE(TECompilerNode);

/*!
* \brief Get target hook from function after checking TargetKind registry
*
* \param func - Function to get hook from
* \param hook_name - Name of hook to acquire
* \return Pointer to the packed function in the registry or nullptr if not found
*/
const PackedFunc* GetTargetHookFromFunction(const Function& func, const String& hook_name) {
auto code_gen_name = func->GetAttr<String>(attr::kCompiler).value();
auto target_kind = tvm::TargetKind::Get(code_gen_name);
if (target_kind) {
return target_kind.value().GetRegisteredHook(hook_name);
}
return nullptr;
}

class TECompilerImpl : public TECompilerNode {
public:
// Lower the function.
Expand Down Expand Up @@ -112,10 +128,12 @@ class TECompilerImpl : public TECompilerNode {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
std::string code_gen_name = code_gen.value();
// Skip this function if it was actually lowered to TIR instead of a Runtime Module
if (GetTargetHookFromFunction(src_func, "relay_to_tir") != nullptr) {
continue;
}
cached_ext_funcs.push_back(it.first);

auto code_gen_name = src_func->GetAttr<String>(attr::kCompiler).value();
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
Expand Down Expand Up @@ -185,17 +203,28 @@ class TECompilerImpl : public TECompilerNode {
}
cur_ccache_key_ = key;

// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto ir_module = IRModule();
const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
auto func_name = GetUniqueName(name_node.value(), &name_map_);
auto target = Target("ext_dev");
auto global_var = GlobalVar(func_name);
global_var->checked_type_ = key->source_func->checked_type();

auto ir_module = IRModule();
ir_module->Add(global_var, key->source_func);

// Lower to TIR if we have a registered lowering hook
auto custom_lowering_to_tir = GetTargetHookFromFunction(key->source_func, "relay_to_tir");
if (custom_lowering_to_tir != nullptr) {
IRModule lowered_module = (*custom_lowering_to_tir)(ir_module, key->source_func);
value->cached_func =
CachedFunc(key->target, global_var, {}, {}, te::Schedule(), {}, lowered_module);
return value;
}

// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
auto target = Target("ext_dev");
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
return value;
}
Expand Down Expand Up @@ -340,7 +369,9 @@ class LowerTensorExpr : public ExprMutator {

Target target;

if (func->GetAttr<String>(attr::kCompiler).defined()) {
// If a custom lowering hook is registered, it will be resolved during the call to Lower()
if (func->GetAttr<String>(attr::kCompiler).defined() &&
GetTargetHookFromFunction(func, "relay_to_tir") == nullptr) {
target = Target("ext_dev");
CCacheKey key = CCacheKey(func, target);
CachedFunc ext_func = compiler_->Lower(key, module_name_);
Expand Down Expand Up @@ -414,7 +445,7 @@ class LowerTensorExpr : public ExprMutator {
ProcessFn process_fn;
String module_name_;
TECompiler compiler_;
};
}; // namespace tec

/*!
* \brief Obtain the Target from the device type.
Expand Down
12 changes: 12 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
return reg->kind_;
}

const PackedFunc* TargetKind::GetRegisteredHook(const String& hook_name) const {
auto map = tvm::TargetKind::GetAttrMap<String>(hook_name);
if (map.count(*this)) {
std::string hook_function = map[*this];
return tvm::runtime::Registry::Get(hook_function);
}
return nullptr;
}

/********** Utility functions **********/

/*!
Expand Down Expand Up @@ -353,6 +362,9 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break

TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");

TVM_REGISTER_TARGET_KIND("test", kDLCPU)
.set_attr<String>("relay_to_tir", "target.test.tir_lowering");

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@

using namespace tvm;

TVM_REGISTER_GLOBAL("target.test_kind.test_registered_function")
.set_body_typed([](IRModule mod, Target target) { return mod; });

TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
.set_attr<std::string>("Attr1", "Value1")
.set_attr<String>("known_hook", "target.test_kind.test_registered_function")
.set_attr<String>("unknown_hook", "target.test_kind.test_not_registered_function")
.add_attr_option<Bool>("my_bool")
.add_attr_option<Array<String>>("your_names")
.add_attr_option<Map<String, Integer>>("her_maps");
Expand Down Expand Up @@ -158,6 +163,20 @@ TEST(TargetKindRegistryListTargetKinds, Basic) {
ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1);
}

TEST(TargetHookCheck, HookRegisteredNonNull) {
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
const PackedFunc* target_hook =
tvm::runtime::Registry::Get("target.test_kind.test_registered_function");
ICHECK_NE(target_hook, (const PackedFunc*)nullptr);
ICHECK_EQ(target_kind.GetRegisteredHook("known_hook"), target_hook);
}

TEST(TargetHookCheck, HookRegisteredNull) {
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
const PackedFunc* unknown_func = nullptr;
ICHECK_EQ(target_kind.GetRegisteredHook("unknown_hook"), unknown_func);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
Expand Down
88 changes: 88 additions & 0 deletions tests/python/relay/test_target_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for target hooks."""
import sys
import numpy as np
import pytest

import tvm
import tvm.relay.testing
import tvm.relay.transform

from tvm import relay
from utils.external_codegen import (
set_external_func_attr,
check_aot_executor_result,
check_graph_executor_result,
)


def translate_relay_add_to_tir_subtract(ir_module, relay_func):
"""A transform to test Relay -> TIR with"""
ib = tvm.tir.ir_builder.create()
A = tvm.tir.decl_buffer(
dtype=relay_func.params[0].checked_type.dtype,
name=relay_func.params[0].name_hint,
shape=relay_func.params[0].checked_type.shape,
)
B = tvm.tir.decl_buffer(
dtype=relay_func.params[1].checked_type.dtype,
name=relay_func.params[1].name_hint,
shape=relay_func.params[1].checked_type.shape,
)
C = tvm.tir.decl_buffer(dtype=relay_func.ret_type.dtype, shape=relay_func.ret_type.shape)

Ap = ib.buffer_ptr(A)
Bp = ib.buffer_ptr(B)
Cp = ib.buffer_ptr(C)

with ib.for_range(0, 8, name="i") as i:
with ib.for_range(0, 8, name="j") as j:
row = i * 8
Cp[row + j] = Ap[row + j] - Bp[row + j]

prim_func = tvm.tir.PrimFunc([A, B, C], ib.get())

ir_module = tvm.lower(prim_func, name=relay_func.attrs["global_symbol"])
return ir_module


@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_aot_executor_result])
def test_tir_external_generation(check_result):
tvm.register_func("target.test.tir_lowering", translate_relay_add_to_tir_subtract, True)

shape = (8, 8)
x_data = np.random.randint(255, size=shape).astype("float32")
y_data = np.random.randint(255, size=shape).astype("float32")
inputs = {"x": x_data, "y": y_data}

x0 = relay.var("x0", shape=shape, dtype="float32")
y0 = relay.var("y0", shape=shape, dtype="float32")
z = x0 + y0
f = relay.Function([x0, y0], z)
f = set_external_func_attr(f, "test", "replace_add_with_subtract")

x = relay.var("x", shape=(8, 8), dtype="float32")
y = relay.var("y", shape=(8, 8), dtype="float32")
call = relay.Call(f, [x, y])
func = tvm.IRModule.from_expr(call)

check_result(func, inputs, (8, 8), x_data - y_data)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 997ae2d

Please sign in to comment.