Skip to content

Commit

Permalink
Implementation of relay_to_tir target hook (apache#8423)
Browse files Browse the repository at this point in the history
This the first new hook proposed in the Additional Target Hooks RFC, longer
term the compilation should move to using `Target` proper but this unblocks our current work whilst illustrating the eventual interface via `Target` in `src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc`

Ideally the host target would be annotated onto the `IRModule` so as this `Pass` could use it instead of defaulting to C but this is fine for now.
  • Loading branch information
Mousius authored and ylc committed Sep 29, 2021
1 parent b0e0489 commit 0cdfb86
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 9 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ include(cmake/modules/contrib/EthosU.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
Expand Down
19 changes: 19 additions & 0 deletions cmake/modules/contrib/ExampleTargetHooks.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc)
list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC})
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
*/
TVM_DLL Pass SimplifyExpr();

/*!
* \brief Run any registered RelayToTIR passes registered on the functions in a module.
*
* \return The pass.
*/
TVM_DLL Pass RelayToTIRTargetHook();

/*!
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
Expand Down
131 changes: 131 additions & 0 deletions src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace relay {
namespace contrib {
namespace example_target_hooks {

class ConvertAddToSubtract : public MixedModeMutator {
public:
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module), host_target_(host_target) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
BaseFunc main = ir_module_->Lookup(main_global_var);
Function main_func = GetRef<Function>(main.as<FunctionNode>());

// Copy everything across and mutate the body
Function mutated_main =
Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type,
main_func->type_params, main_func->attrs, main_func->span);

ir_module_->Update(main_global_var, mutated_main);

return ir_module_;
}

private:
tir::Load LoadIndex(const tir::Buffer& buffer, const PrimExpr& index) {
return tir::Load(DataType::Float(32), buffer->data, index, tir::const_true());
}

void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));

tir::Var x_var("x", DataType::Handle());
tir::Var y_var("y", DataType::Handle());
tir::Var out_var("out", DataType::Handle());

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", new_global_var->name_hint);
dict_attrs.Set("tir.noalias", Bool(true));

te::Var index("index", DataType::Int(32));
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
tir::Stmt math_body = tir::Store(out_buffer->data, indexed_sub, index, tir::const_true());
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);

Map<tir::Var, tir::Buffer> buffer_map = {
{x_var, x_buffer},
{y_var, y_buffer},
{out_var, out_buffer},
};

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
buffer_map, DictAttrs(dict_attrs));
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
ir_module_->Add(new_global_var, replacement_func);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return post;
}

auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
if (func_name.defined() && func_name == "replace_add_with_subtract") {
// Introduce a new global var to map the function to and copy the source type
// over for InferType
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
}
}

return post;
}

public:
IRModule ir_module_;
Target host_target_;
};

transform::Pass RelayToTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule ir_module, transform::PassContext pass_context) {
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
return relay_to_tir.Mutate();
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
}

} // namespace example_target_hooks
} // namespace contrib
} // namespace relay

TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<tvm::transform::Pass>("RelayToTIR",
relay::contrib::example_target_hooks::RelayToTIR());

} // namespace tvm
35 changes: 26 additions & 9 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class TECompilerImpl : public TECompilerNode {
Array<tvm::runtime::Module> ret;
std::unordered_map<std::string, std::string> cached_symbol;
std::vector<CCacheKey> cached_ext_funcs;

for (const auto& it : cache_) {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
Expand Down Expand Up @@ -383,10 +384,12 @@ class LowerTensorExprMutator : public ExprMutator {
* \brief Returns the primitive function associated with \p expr, or
* nullptr if none.
*/
Function ResolveToPrimitive(Expr expr) {
BaseFunc ResolveToPrimitive(Expr expr) {
if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
BaseFunc base_func = module_->Lookup(GetRef<GlobalVar>(gvn));
return ResolveToPrimitive(base_func);
} else if (const tir::PrimFuncNode* prim_func = expr.as<tir::PrimFuncNode>()) {
return GetRef<tir::PrimFunc>(prim_func);
} else if (const VarNode* vn = expr.as<VarNode>()) {
auto itr = primitive_functions_.find(GetRef<Var>(vn));
return itr == primitive_functions_.end() ? Function() : itr->second;
Expand Down Expand Up @@ -516,10 +519,17 @@ class LowerTensorExprMutator : public ExprMutator {
Expr VisitExpr_(const LetNode* let) override {
Var var = Downcast<Var>(Mutate(let->var));
Expr value = Mutate(let->value);
Function prim_func = ResolveToPrimitive(value);
BaseFunc prim_func = ResolveToPrimitive(value);

if (prim_func.defined()) {
// Already lowered by other means, no need to mutate the Let node
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
return GetRef<Let>(let);
}

// Remember let var is bound to (possibly indirectly) to a primitive.
primitive_functions_.emplace(let->var, prim_func);
Function func = Downcast<Function>(prim_func);
primitive_functions_.emplace(let->var, func);
}
Expr body = Mutate(let->body);
if (prim_func.defined()) {
Expand All @@ -537,7 +547,7 @@ class LowerTensorExprMutator : public ExprMutator {
Call expr = GetRef<Call>(call);

// Look for (indirect) calls to primitives.
Function prim_func = ResolveToPrimitive(call->op);
BaseFunc prim_func = ResolveToPrimitive(call->op);
if (!prim_func.defined()) {
// Not a call to a primitive function.
if (const FunctionNode* fn = call->op.as<FunctionNode>()) {
Expand All @@ -546,6 +556,12 @@ class LowerTensorExprMutator : public ExprMutator {
return ExprMutator::VisitExpr_(call);
}

// Already lowered by other means so we don't need to mutate
// the call
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
return expr;
}

// Find the desired target device.
Target target;
if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
Expand All @@ -565,7 +581,8 @@ class LowerTensorExprMutator : public ExprMutator {
}

// Lower the primitive function for that target.
std::pair<GlobalVar, Attrs> pair = LowerFunction(prim_func, target);
Function func = Downcast<Function>(prim_func);
std::pair<GlobalVar, Attrs> pair = LowerFunction(func, target);

// Similarly transform arguments.
Array<Expr> args;
Expand Down Expand Up @@ -639,8 +656,6 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const Stri

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map) {
CHECK_EQ(mod->functions.size(), 1)
<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
Function func = Downcast<Function>(mod->Lookup("main"));

// This is a Map<device,Map<storage_id, size>>
Expand Down Expand Up @@ -909,8 +924,10 @@ Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String&
PassContext ctx) {
return LowerTE(module, targets, device_context_map, module_name, process_fn);
};
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});

return tvm::transform::Sequential({tvm::relay::transform::RelayToTIRTargetHook(),
tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}),
InferType()});
}
} // namespace tec
} // namespace relay
Expand Down
86 changes: 86 additions & 0 deletions src/relay/transforms/target_hooks.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file target_hooks.cc
* \brief Relay passes for processing Target Hooks which have been registered on functions within
* the IRModule
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {
namespace transform {

class TargetHookVisitor : public tvm::relay::MixedModeVisitor {
/*! \brief Collected pass list for all nodes */
std::vector<Pass> pass_list_;
/*! \brief Attribute map for all registered targets */
TargetKindAttrMap<Pass> target_attr_map_;

public:
TargetHookVisitor() : target_attr_map_(tvm::TargetKind::GetAttrMap<Pass>("RelayToTIR")) {}

std::vector<Pass> Visit(const IRModule& ir_mod) {
for (const auto& it : ir_mod->functions) {
const BaseFunc& base_func = it.second;
VisitExpr(base_func);
}
return pass_list_;
}

void VisitExpr_(const CallNode* call) override {
// Descend the call tree
for (auto arg : call->args) {
VisitExpr(arg);
}

if (const FunctionNode* func = call->op.as<FunctionNode>()) {
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
return;
}
String code_gen_name = func->GetAttr<String>(attr::kCompiler).value();
Optional<TargetKind> target_kind = tvm::TargetKind::Get(code_gen_name);
if (!target_kind || !target_attr_map_.count(target_kind.value())) {
return;
}
Pass custom_target_pass = target_attr_map_[target_kind.value()];
if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) {
pass_list_.push_back(custom_target_pass);
}
}
}
};

Pass RelayToTIRTargetHook() {
auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
auto target_hook_visitor = TargetHookVisitor();
std::vector<Pass> pass_list = target_hook_visitor.Visit(mod);
Sequential run_hooks(pass_list);

return run_hooks(mod);
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {});
}

} // namespace transform
} // namespace relay
} // namespace tvm
Loading

0 comments on commit 0cdfb86

Please sign in to comment.