Skip to content

Commit

Permalink
Initial Implementation of TIRToRuntime Target hook (apache#9190)
Browse files Browse the repository at this point in the history
* Initial Implementation of TIRToRuntime Target hook

This is the initial implementation which wires in a test case for TIRToRuntime, in order to get this working I re-used `CodegenCHost` as it implements all of the `Op`s required from the lowered `PrimFunc`.

Currently, the `IRModule` is non-unified but in future work it should definitely do so, I wanted to implement the basics here to get the infra in place.

* Fix heterogeneous compute with multiple kDLCPU targets

* Remove rogue te_compiler.h include
  • Loading branch information
Mousius authored and ylc committed Jan 13, 2022
1 parent 3971ca0 commit 47e8368
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cmake/modules/contrib/ExampleTargetHooks.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
# specific language governing permissions and limitations
# under the License.

file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc)
file(GLOB EXAMPLE_TARGET_HOOKS_SRC src/relay/backend/contrib/example_target_hooks/*.cc)
list(APPEND COMPILER_SRCS ${EXAMPLE_TARGET_HOOKS_SRC})
28 changes: 28 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_TARGET_TARGET_KIND_H_
#define TVM_TARGET_TARGET_KIND_H_

#include <tvm/ir/transform.h>
#include <tvm/node/attr_registry_map.h>
#include <tvm/node/node.h>

Expand All @@ -33,6 +34,33 @@
#include <vector>

namespace tvm {

class Target;

/*!
* \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
*
* Called before the default lowering passes.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
using FTVMRelayToTIR = transform::Pass;

/*!
* \brief TIRToRuntime conversion specific to a TargetKind
*
* This function is responsible for scanning an IRModule for appropriate Target-specific functions
and generating a Runtime module representing the compiled output
*
* \param ir_module Unified IRModule
* \param target Target to filter on or retrieve arguments from
* \return Runtime Module containing compiled functions
*/
using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;

namespace detail {
template <typename, typename, typename>
struct ValueTypeInfoMaker;
Expand Down
28 changes: 25 additions & 3 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,21 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
auto opt_mixed = transform::Sequential(mixed_pass_list);
mod_mixed = opt_mixed(std::move(mod_mixed));

// We make an assumption here that the overriden host target
// can be used alongside the default host codegen based on device type
// this is so the correct code generator is used later instead of overriding the target.
// We need better support for inserting multiple kDLCPU targets as our current options
// are kDeviceKernelLaunch or not
Target overriden_host_target = target_host;
if (target->kind->device_type == target_host->kind->device_type) {
overriden_host_target = target;
}
auto host_pass_list = {
Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
CallingConv::kDeviceKernelLaunch;
}),
BindTarget(target_host),
BindTarget(overriden_host_target),
tir::transform::LowerTVMBuiltin(),
tir::transform::LowerCustomDatatypes(),
tir::transform::LowerIntrin(),
Expand Down Expand Up @@ -487,15 +496,27 @@ runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& tar

for (const auto& it : inputs) {
if (it.second.defined()) {
auto pair = SplitDevHostFuncs(it.second, it.first, target_host, pass_ctx);
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitDevHostFuncs(ir_module, target, target_host, pass_ctx);
auto& mhost = pair.first;
auto& mdevice = pair.second;

ICHECK(mhost.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";

mhost_all->Update(mhost);
// We don't want library modules going back into host codegen
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
bool overrides_host_target = target->kind->device_type == target_host->kind->device_type;
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(mhost, it.first));
} else {
mhost_all->Update(mhost);
}

if (mdevice->functions.size() != 0) {
device_modules.push_back(codegen::Build(mdevice, it.first));
Expand All @@ -510,6 +531,7 @@ runtime::Module build(const Map<Target, IRModule>& inputs_arg, const Target& tar
mhost.Import(it);
}
}

return mhost;
}

Expand Down
19 changes: 13 additions & 6 deletions src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ namespace example_target_hooks {
class ConvertAddToSubtract : public MixedModeMutator {
public:
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module), host_target_(host_target) {}
: ir_module_(ir_module),
host_target_(host_target),
custom_target_(Target("example_target_hook")) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Expand Down Expand Up @@ -81,7 +83,15 @@ class ConvertAddToSubtract : public MixedModeMutator {

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
buffer_map, DictAttrs(dict_attrs));
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);

// Switch to TIRToRuntime hook for testing
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
if (tir_to_runtime) {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
} else {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
}

ir_module_->Add(new_global_var, replacement_func);
}

Expand Down Expand Up @@ -109,6 +119,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
public:
IRModule ir_module_;
Target host_target_;
Target custom_target_;
};

transform::Pass RelayToTIR() {
Expand All @@ -124,8 +135,4 @@ transform::Pass RelayToTIR() {
} // namespace contrib
} // namespace relay

TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<tvm::transform::Pass>("RelayToTIR",
relay::contrib::example_target_hooks::RelayToTIR());

} // namespace tvm
39 changes: 39 additions & 0 deletions src/relay/backend/contrib/example_target_hooks/target.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relay/transform.h>
#include <tvm/target/target.h>

namespace tvm {

namespace relay {
namespace contrib {
namespace example_target_hooks {
tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);
} // namespace example_target_hooks
} // namespace contrib
} // namespace relay

TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
.set_attr<FTVMRelayToTIR>("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime);

} // namespace tvm
64 changes: 64 additions & 0 deletions src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <sstream>
#include <string>

#include "../../../../target/source/codegen_c_host.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace example_target_hooks {

using namespace tir;

class CodeGenExampleTargetHook : public codegen::CodeGenCHost {
public:
/*!
* \brief Emit code that changes adds to multiplies for testing
*/
void VisitExpr_(const SubNode* op, std::ostream& os) final {
os << '(';
PrintExpr(op->a, os);
os << " * ";
PrintExpr(op->b, os);
os << ')';
}
};

runtime::Module TIRToRuntime(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
CodeGenExampleTargetHook codegen;
Array<String> function_names;
codegen.Init(output_ssa, emit_asserts, target->str());
for (auto kv : mod->functions) {
auto prim_func = Downcast<PrimFunc>(kv.second);
auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
function_names.push_back(global_symbol.value());
codegen.AddFunction(prim_func);
}
std::string code = codegen.Finish();
return codegen::CSourceModuleCreate(code, "c", function_names);
}

} // namespace example_target_hooks
} // namespace contrib
} // namespace relay
} // namespace tvm
5 changes: 5 additions & 0 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ runtime::Module Build(IRModule mod, Target target) {
mod = tir::transform::SkipAssert()(mod);
}

auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
if (target_attr_map.count(target->kind)) {
return target_attr_map[target->kind](mod, target);
}

// the build function.
std::string build_f_name = "target.build." + target->kind->name;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
namespace tvm {
namespace codegen {

class CodeGenCHost final : public CodeGenC {
class CodeGenCHost : public CodeGenC {
public:
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts, std::string target_str);
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relay/test_target_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,28 @@ def test_tir_external_generation(check_result):
check_result(func, inputs, (8,), x_data - y_data)


@pytest.mark.parametrize("check_result", [check_aot_executor_result, check_graph_executor_result])
def test_runtime_module_generation(check_result):
shape = (8,)
x_data = np.random.randint(255, size=shape).astype("float32")
y_data = np.random.randint(255, size=shape).astype("float32")
inputs = {"x": x_data, "y": y_data}

x0 = relay.var("x0", shape=shape, dtype="float32")
y0 = relay.var("y0", shape=shape, dtype="float32")
z = x0 + y0
func = relay.Function([x0, y0], z)
func = set_external_func_attr(func, "example_target_hook", "replace_add_with_subtract")
# Test hook to trigger TIRToRuntime code generation
func = func.with_attr("tir_to_runtime", True)

x = relay.var("x", shape=(8,), dtype="float32")
y = relay.var("y", shape=(8,), dtype="float32")
call = relay.Call(func, [x, y])
func = IRModule.from_expr(call)

check_result(func, inputs, (8,), x_data * y_data)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 47e8368

Please sign in to comment.