Merge pull request #2 from gigiblender/mangle-tir-func
[Relax][AOT] Add pass that mangles TIR PrimFunc names
mbaret authored Dec 5, 2022
2 parents 66eae17 + 08b2ec1 commit 55a2475
Showing 5 changed files with 198 additions and 6 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
@@ -696,6 +696,12 @@ TVM_DLL Pass ManifestSharedMemoryLocalStage();
 */
TVM_DLL Pass InstrumentProfileIntrinsics();

/*!
 * \brief Mangle TIR function names by appending a prefix to avoid symbol collisions.
 * \return The pass.
 */
TVM_DLL Pass TIRFuncRename();

}  // namespace transform
}  // namespace tir
}  // namespace tvm
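As a hedged sketch of the intended use (the wrapper name here is invented, and it assumes a lowered IRModule whose PrimFuncs carry global_symbol attributes): two independently generated modules can both emit a symbol such as "fused_add", which collides at link time; running the pass gives each module's PrimFuncs fresh names.

#include <tvm/ir/module.h>
#include <tvm/tir/transform.h>

// Rename every non-main PrimFunc so that linking this module alongside
// other generated code cannot produce duplicate symbols.
tvm::IRModule MangleForLinking(tvm::IRModule mod) {
  return tvm::tir::transform::TIRFuncRename()(mod);
}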
1 change: 1 addition & 0 deletions src/relax/backend/aot/codegen_aot.cc
@@ -77,6 +77,7 @@ runtime::Module Build(IRModule mod, String mod_name, CompilationConfig config, r
  }
  mod = AOTLowerMain(mod_name, config)(mod);
  mod = tir::transform::LegalizePackedCalls()(mod);
  mod = tir::transform::TIRFuncRename()(mod);

  auto lowered_funcs = tvm::relay::tec::GetPerTargetModules(mod);
  auto exec_metadata = tvm::relay::backend::aot::CreateExecutorMetadata(mod, mod_name, executor, workspace_byte_alignment,
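For orientation, a hedged sketch of the surrounding Build() ordering (taken from this hunk; the comments are editorial, not from the source): the rename runs after packed-call legalization, so each callee already appears as a string argument that the pass can rewrite.

mod = AOTLowerMain(mod_name, config)(mod);         // lower the Relax main to TIR
mod = tir::transform::LegalizePackedCalls()(mod);  // lower calls to packed-call form
mod = tir::transform::TIRFuncRename()(mod);        // mangle symbols and patch call sites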
18 changes: 12 additions & 6 deletions src/relax/usmp/transform/convert_relax_to_dps.cc
@@ -123,12 +123,18 @@ class ConvertRelaxMainToDPS : public ExprMutator {
    for (auto iter = block->bindings.rbegin(); iter != block->bindings.rend(); iter++) {
      Binding binding = *iter;
      if (const auto* var_binding = binding.as<VarBindingNode>()) {
-       if (var_binding->value->IsInstance<VarNode>() &&
-           return_alias_.count(var_binding->var) > 0) {
-         // Alias. Update alias map and do not emit binding.
-         return_alias_[runtime::Downcast<Var>(var_binding->value)] =
-             return_alias_[var_binding->var];
-         continue;
+       if (var_binding->value->IsInstance<VarNode>()) {
+         if (return_alias_.count(var_binding->var) > 0) {
+           // Alias. Update alias map and do not emit binding.
+           return_alias_[runtime::Downcast<Var>(var_binding->value)] =
+               return_alias_[var_binding->var];
+           continue;
+         }
+         if (return_alias_.count(var_binding->var) == 0 &&
+             return_alias_.count(var_binding->value) > 0) {
+           // Creating an alias for a dead var. Do not emit binding.
+           continue;
+         }
        }

        if (var_binding->value->IsInstance<relay::TupleNode>() &&
128 changes: 128 additions & 0 deletions src/tir/transforms/tir_func_rename.cc
@@ -0,0 +1,128 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tir/transforms/tir_func_rename.cc
 * \brief Mangles TIR function names to avoid symbol conflicts.
 * Appends "_tvm_gen" to all function names in the IRModule.
 */

#include <string>
#include <unordered_map>
#include <utility>

#include "tvm/ir/name_supply.h"
#include "tvm/ir/transform.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/stmt_functor.h"

namespace tvm {
namespace tir {
namespace aot {

class TIRMangleFuncName : public StmtExprMutator {
 public:
  explicit TIRMangleFuncName(IRModule mod) : mod_(std::move(mod)) {
    ICHECK(mod_->ContainGlobalVar(runtime::symbol::tvm_module_main))
        << "Expecting module to have"
        << " symbol " << runtime::symbol::tvm_module_main << " attached.";
    auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);
    NameSupply name_supply = NameSupply("_tvm_gen");
    for (auto pair : mod_->functions) {
      if (pair.first.same_as(main_func_gv)) {
        // Ignore the main function.
        continue;
      }
      auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
      auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
      ICHECK(func_name.defined()) << "Expecting global_symbol attribute to be attached to the"
                                     " function";
      name_map_[func_name.value()] = name_supply->FreshName(func_name.value());
    }
  }

  IRModule operator()() {
    auto main_func_gv = mod_->GetGlobalVar(runtime::symbol::tvm_module_main);

    Map<GlobalVar, BaseFunc> func_map = Map<GlobalVar, BaseFunc>();
    for (auto pair : mod_->functions) {
      auto prim_func = runtime::Downcast<PrimFunc>(pair.second);
      auto func_name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);

      Stmt new_body = this->VisitStmt(prim_func->body);
      if (pair.first.same_as(main_func_gv)) {
        // No need to set a new global var and global symbol for the main function.
        func_map.Set(pair.first, PrimFunc(prim_func->params, new_body, prim_func->ret_type,
                                          prim_func->buffer_map, prim_func->attrs, prim_func->span));
      } else {
        ICHECK(name_map_.count(func_name.value()) > 0) << "Expecting new name in name_map_ at "
                                                          "this stage.";
        GlobalVar new_var = GlobalVar(name_map_[func_name.value()]);
        PrimFunc new_func = PrimFunc(prim_func->params, new_body, prim_func->ret_type,
                                     prim_func->buffer_map, prim_func->attrs, prim_func->span);
        new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol,
                            String(name_map_[func_name.value()]));
        func_map.Set(new_var, new_func);
      }
    }

    IRModule new_mod = IRModule(func_map, mod_->type_definitions, mod_->Imports(),
                                mod_->source_map, mod_->attrs);
    return new_mod;
  }

 private:
  PrimExpr VisitExpr_(const CallNode* op) override {
    String func_name;
    if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
      func_name = Downcast<StringImm>(op->args[0])->value;
    }
    if (op->op->IsInstance<PrimFuncNode>()) {
      func_name = Downcast<StringImm>(op->args[0])->value;
    }
    if (func_name.defined() && mod_->ContainGlobalVar(func_name) &&
        mod_->Lookup(func_name)->IsInstance<PrimFuncNode>()) {
      ICHECK(name_map_.count(func_name) > 0) << "Name map should contain a name.";
      StringImm new_name = StringImm(name_map_[func_name]);
      Array<PrimExpr> new_args = {new_name};
      new_args.insert(new_args.end(), op->args.begin() + 1, op->args.end());
      return Call(op->dtype, op->op, new_args, op->span);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  std::unordered_map<std::string, std::string> name_map_;
  IRModule mod_;
};

} // namespace aot

namespace transform {

tvm::transform::Pass TIRFuncRename() {
  auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
    return runtime::Downcast<IRModule>(tvm::tir::aot::TIRMangleFuncName(m)());
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tir.transform.TIRFuncRename", {});
}

} // namespace transform
} // namespace tir
} // namespace tvm
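To make the call-site handling in VisitExpr_ concrete, a minimal standalone mirror of it (hedged: RewriteCallee is an invented helper for illustration, not part of the commit):

#include <tvm/tir/expr.h>

// Swap the StringImm callee (args[0]) for its mangled name and copy the
// remaining arguments verbatim, exactly as the mutator above does.
tvm::PrimExpr RewriteCallee(const tvm::tir::CallNode* op, tvm::String new_name) {
  tvm::Array<tvm::PrimExpr> new_args = {tvm::tir::StringImm(new_name)};
  new_args.insert(new_args.end(), op->args.begin() + 1, op->args.end());
  return tvm::tir::Call(op->dtype, op->op, new_args, op->span);
}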
51 changes: 51 additions & 0 deletions tests/python/relax/test_relax_usmp_convert_to_dps.py
@@ -233,6 +233,57 @@ def test_tuple_both_alloc():
        # tvm.ir.assert_structural_equal(actual_func, ref_func)


# fmt: off
@tvm.script.ir_module
class TestTupleBothAllocDeadCode:
    @R.function
    def main(input: R.Tensor((16, 16), "uint8")) -> R.Tuple(R.Tensor(None, "float32", ndim = 2), R.Tensor(None, "int32", ndim = 2)):
        # block 0
        tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
        alloc = R.builtin.alloc_tensor((5, 7), dtype="float32", runtime_device_index=0)
        _ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
        output_1 = alloc

        alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
        _1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
        lv0 = alloc1

        tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
        alloc2 = R.builtin.alloc_tensor((802816, 1), dtype="int32", runtime_device_index=0)
        _2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
        output_2 = alloc2
        output = (alloc, alloc2)
        gv = output
        return output


@tvm.script.ir_module
class TestTupleBothAllocDeadCodeExpected:
    @R.function
    def main(input: R.Tensor((16, 16), "uint8"), alloc: R.Tensor((5, 7), "float32"), alloc2: R.Tensor((802816, 1), "int32")):
        # block 0
        tsid_11 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
        _ = R.call_packed("prim_func_2", input, tsid_11, alloc, type_args=(R.Tensor(ndim=2, dtype="float32")))
        alloc1 = R.builtin.alloc_tensor((5, 7), dtype="int8", runtime_device_index=0)
        _1 = R.call_packed("prim_func_3", input, alloc, alloc1, type_args=(R.Tensor(ndim=2, dtype="int8")))
        lv0 = alloc1
        tsid_12 = R.builtin.alloc_tensor((1, 1), dtype="int8", runtime_device_index=0)
        _2 = R.call_packed("prim_func_1", input, lv0, tsid_12, alloc2, type_args=(R.Tensor(ndim=2, dtype="int32")))
        return R.Tuple()

# fmt: on


def test_tuple_both_alloc_dead_code():
    before_mod = TestTupleBothAllocDeadCode
    after_mod = tvm.relax.transform.ConvertRelaxMainToDPS(attach_io_to_attrs=False)(before_mod)
    expected_mod = TestTupleBothAllocDeadCodeExpected
    for gv, ref_func in expected_mod.functions.items():
        actual_func = after_mod[gv.name_hint]
        assert str(actual_func) == str(ref_func)
        # tvm.ir.assert_structural_equal(actual_func, ref_func)


# fmt: off
@tvm.script.ir_module
class TestTupleOneAllocOneParam:
