From b118b30edb2e07b442eefa2ceeeea63c63c49a71 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 13 Nov 2019 18:57:51 -0800 Subject: [PATCH 1/6] [Relay][Pass] Add pass to remove unused functions in relay module --- python/tvm/relay/transform.py | 9 ++ src/relay/backend/vm/compiler.cc | 2 + src/relay/backend/vm/removed_unused_funcs.cc | 120 +++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 src/relay/backend/vm/removed_unused_funcs.cc diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index d3509dabddf9..b693555c6f4c 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -297,6 +297,15 @@ def BackwardFoldScaleAxis(): """ return _transform.BackwardFoldScaleAxis() +def RemoveUnusedFunctions(): + """Remove unused global relay functions in a relay module. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to remove unused functions. + """ + return _transform.RemoveUnusedFunctions() def ForwardFoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7f828c473bbe..947c9dc64917 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -54,6 +54,7 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); +Pass RemoveUnusedFunctions(); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); @@ -863,6 +864,7 @@ void VMCompiler::Compile(Module mod, Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { Array pass_seqs; + pass_seqs.push_back(transform::RemoveUnusedFunctions()); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc new file mode 100644 index 000000000000..54356bc6cc6e --- /dev/null +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/backend/vm/remove_unused_funcs.cc + * \brief Remove unused global relay functions in a relay module. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace vm { + +/** + * \brief Detects all the functions that can be possibly called by entry function. + */ +struct CallTracer : ExprVisitor { + Module module_; + + // Record the names of all encountered functions + std::unordered_set called_funcs_; + + // Remember the functions seen to avoid infinite loop + std::unordered_set visited_; + + explicit CallTracer(const Module& module) + : module_{module}, + called_funcs_{}, + visited_{} {} + + void VisitExpr_(const CallNode* call_node) final { + Expr op = call_node->op; + if (auto func_node = op.as()) { + auto func = GetRef(func_node); + auto it = visited_.find(func); + if (it != visited_.end()) { + return; + } + VisitExpr(func); + visited_.insert(func); + } else if (auto global = op.as()) { + called_funcs_.insert(global->name_hint); + auto func = module_->Lookup(global->name_hint); + auto it = visited_.find(func); + if (it != visited_.end()) { + return; + } + VisitExpr(func); + visited_.insert(func); + } + for (auto param: call_node->args) { + VisitExpr(param); + } + } + + std::unordered_set Trace(const std::string& entry) { + called_funcs_.insert(entry); + auto main_func = module_->Lookup(entry); + VisitExpr(main_func); + return called_funcs_; + } +}; + +Module RemoveUnusedFunctions(const Module& module) { + auto called_funcs = CallTracer(module).Trace("main"); + auto existing_functions = module->functions; + for (auto f : existing_functions) { + auto it = called_funcs.find(f.first->name_hint); + if (it == called_funcs.end()) { + module->Remove(f.first); + } + } + return module; +} + +} // namespace vm + +namespace transform { + +Pass RemoveUnusedFunctions() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::RemoveUnusedFunctions(m); + }; + return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); +} + +TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions") +.set_body_typed(RemoveUnusedFunctions); + +} // namespace transform + +} // namespace relay +} // namespace tvm From 08fe41dd3cec1710c641c699772adb15e6dbaf24 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 13 Nov 2019 21:04:35 -0800 Subject: [PATCH 2/6] Add tests --- .../test_pass_remove_unused_functions.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/python/relay/test_pass_remove_unused_functions.py diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py new file mode 100644 index 000000000000..c9fd2762dae3 --- /dev/null +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.prelude import Prelude + +def test_remove_all_prelude_functions(): + mod = relay.Module() + p = Prelude(mod) + x = relay.var("x", shape=(1, 16)) + mod["main"] = relay.Function([x], x) + mod = relay.transform.RemoveUnusedFunctions()(mod) + # Keep: main + assert len(mod.functions) == 1 + +def test_remove_all_prelude_functions_but_referenced_functions(): + mod = relay.Module() + p = Prelude(mod) + x = relay.var("x", shape=(1, 16)) + id_func = relay.Function([x], x) + id_name = relay.GlobalVar('id_func') + mod[id_name] = id_func + + mod["main"] = relay.Function([x], id_name(x)) + mod = relay.transform.RemoveUnusedFunctions()(mod) + # Keep: id_func, main + assert len(mod.functions) == 2 + +def test_keep_only_referenced_prelude_functions(): + mod = relay.Module() + p = Prelude(mod) + l = p.nil() + for i in [4, 3, 2, 1, 0]: + l = p.cons(relay.const(i), l) + body = p.hd(p.tl(p.tl(l))) + mod["main"] = relay.Function([], body) + mod = relay.transform.RemoveUnusedFunctions()(mod) + # Keep: hd, tl, main + assert len(mod.functions) == 3 + +if __name__ == '__main__': + pytest.main() From 429ce2580144cf4f5c3d544c72b5d4966967cbc1 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Wed, 13 Nov 2019 23:05:30 -0800 Subject: [PATCH 3/6] Fix lint --- src/relay/backend/vm/removed_unused_funcs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 54356bc6cc6e..e1a6daa3e117 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -74,7 +74,7 @@ struct CallTracer : ExprVisitor { VisitExpr(func); visited_.insert(func); } - for (auto param: call_node->args) { + for (auto param : call_node->args) { VisitExpr(param); } } From 7e596eae32e547a1a0d630b434f591bceae15526 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 14 Nov 2019 07:35:13 -0800 Subject: [PATCH 4/6] Fix visit order --- src/relay/backend/vm/removed_unused_funcs.cc | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index e1a6daa3e117..07933d9c83b9 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -46,36 +46,36 @@ struct CallTracer : ExprVisitor { // Record the names of all encountered functions std::unordered_set called_funcs_; - // Remember the functions seen to avoid infinite loop - std::unordered_set visited_; + // Record the expressions that are being visited + std::unordered_set visiting_; explicit CallTracer(const Module& module) : module_{module}, called_funcs_{}, - visited_{} {} + visiting_{} {} void VisitExpr_(const CallNode* call_node) final { Expr op = call_node->op; + for (auto param : call_node->args) { + VisitExpr(param); + } if (auto func_node = op.as()) { auto func = GetRef(func_node); - auto it = visited_.find(func); - if (it != visited_.end()) { + auto it = visiting_.find(func); + if (it != visiting_.end()) { return; } + visiting_.insert(func); VisitExpr(func); - visited_.insert(func); } else if (auto global = op.as()) { called_funcs_.insert(global->name_hint); auto func = module_->Lookup(global->name_hint); - auto it = visited_.find(func); - if (it != visited_.end()) { + auto it = visiting_.find(func); + if (it != visiting_.end()) { return; } + visiting_.insert(func); VisitExpr(func); - visited_.insert(func); - } - for (auto param : call_node->args) { - VisitExpr(param); } } From d5dbb906405dd0949f63c97b5e01885c687b3b33 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 14 Nov 2019 11:03:58 -0800 Subject: [PATCH 5/6] Add pass argument --- python/tvm/relay/transform.py | 9 ++++-- src/relay/backend/vm/compiler.cc | 7 +++-- src/relay/backend/vm/removed_unused_funcs.cc | 22 +++++++++++--- .../test_pass_remove_unused_functions.py | 30 +++++++++++++++---- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index b693555c6f4c..9f67cc4d647c 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -297,15 +297,20 @@ def BackwardFoldScaleAxis(): """ return _transform.BackwardFoldScaleAxis() -def RemoveUnusedFunctions(): +def RemoveUnusedFunctions(entry_functions=['main']): """Remove unused global relay functions in a relay module. + Parameters + ---------- + entry_functions: list[string] + The set of entry functions to start from. + Returns ------- ret : tvm.relay.Pass The registered pass to remove unused functions. """ - return _transform.RemoveUnusedFunctions() + return _transform.RemoveUnusedFunctions(entry_functions) def ForwardFoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 947c9dc64917..8b63759b6406 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -54,7 +54,7 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass RemoveUnusedFunctions(); +Pass RemoveUnusedFunctions(Array entry_functions); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); @@ -864,7 +864,10 @@ void VMCompiler::Compile(Module mod, Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { Array pass_seqs; - pass_seqs.push_back(transform::RemoveUnusedFunctions()); + Array entry_functions{}; + auto f = tvm::Expr{"main"}; + entry_functions.push_back(f); + pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 07933d9c83b9..a01204077c55 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -87,8 +87,22 @@ struct CallTracer : ExprVisitor { } }; -Module RemoveUnusedFunctions(const Module& module) { - auto called_funcs = CallTracer(module).Trace("main"); +/*! + * \brief Remove functions that are not used. + * + * \param module The Relay module. + * \param entry_funcs The set of functions that can be entry function. + * + * \return The module with dead functions removed. + */ +Module RemoveUnusedFunctions(const Module& module, + Array entry_funcs) { + std::unordered_set called_funcs{}; + for (auto entry : entry_funcs) { + auto* str_name = entry.as(); + auto funcs = CallTracer(module).Trace(str_name->value); + called_funcs.insert(funcs.cbegin(), funcs.cend()); + } auto existing_functions = module->functions; for (auto f : existing_functions) { auto it = called_funcs.find(f.first->name_hint); @@ -103,10 +117,10 @@ Module RemoveUnusedFunctions(const Module& module) { namespace transform { -Pass RemoveUnusedFunctions() { +Pass RemoveUnusedFunctions(Array entry_functions) { runtime::TypedPackedFunc pass_func = [=](Module m, PassContext pc) { - return relay::vm::RemoveUnusedFunctions(m); + return relay::vm::RemoveUnusedFunctions(m, entry_functions); }; return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); } diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index c9fd2762dae3..c4a0c41bfdd1 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -25,8 +25,8 @@ def test_remove_all_prelude_functions(): x = relay.var("x", shape=(1, 16)) mod["main"] = relay.Function([x], x) mod = relay.transform.RemoveUnusedFunctions()(mod) - # Keep: main - assert len(mod.functions) == 1 + l = set([x[0].name_hint for x in mod.functions.items()]) + assert l == set(['main']) def test_remove_all_prelude_functions_but_referenced_functions(): mod = relay.Module() @@ -38,8 +38,8 @@ def test_remove_all_prelude_functions_but_referenced_functions(): mod["main"] = relay.Function([x], id_name(x)) mod = relay.transform.RemoveUnusedFunctions()(mod) - # Keep: id_func, main - assert len(mod.functions) == 2 + l = set([x[0].name_hint for x in mod.functions.items()]) + assert l == set(['id_func', 'main']) def test_keep_only_referenced_prelude_functions(): mod = relay.Module() @@ -50,8 +50,26 @@ def test_keep_only_referenced_prelude_functions(): body = p.hd(p.tl(p.tl(l))) mod["main"] = relay.Function([], body) mod = relay.transform.RemoveUnusedFunctions()(mod) - # Keep: hd, tl, main - assert len(mod.functions) == 3 + l = set([x[0].name_hint for x in mod.functions.items()]) + assert l == set(['tl', 'hd', 'main']) + +def test_multiple_entry_functions(): + mod = relay.Module() + p = Prelude(mod) + l = p.nil() + for i in [4, 3, 2, 1, 0]: + l = p.cons(relay.const(i), l) + body = p.hd(p.tl(p.tl(l))) + mod["main1"] = relay.Function([], body) + + x = relay.var("x", shape=(1, 16)) + id_func = relay.Function([x], x) + id_name = relay.GlobalVar('id_func') + mod[id_name] = id_func + mod["main2"] = relay.Function([x], id_name(x)) + mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod) + l = set([x[0].name_hint for x in mod.functions.items()]) + assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1']) if __name__ == '__main__': pytest.main() From e5904ebb48b1fc27029ec14ce6a4b38c6ebd0f96 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 14 Nov 2019 11:06:59 -0800 Subject: [PATCH 6/6] Fix --- python/tvm/relay/transform.py | 4 +++- src/relay/backend/vm/compiler.cc | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 9f67cc4d647c..0a7512a77d1a 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -297,7 +297,7 @@ def BackwardFoldScaleAxis(): """ return _transform.BackwardFoldScaleAxis() -def RemoveUnusedFunctions(entry_functions=['main']): +def RemoveUnusedFunctions(entry_functions=None): """Remove unused global relay functions in a relay module. Parameters @@ -310,6 +310,8 @@ def RemoveUnusedFunctions(entry_functions=['main']): ret : tvm.relay.Pass The registered pass to remove unused functions. """ + if entry_functions is None: + entry_functions = ['main'] return _transform.RemoveUnusedFunctions(entry_functions) def ForwardFoldScaleAxis(): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8b63759b6406..06705b422afa 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -864,9 +864,7 @@ void VMCompiler::Compile(Module mod, Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { Array pass_seqs; - Array entry_functions{}; - auto f = tvm::Expr{"main"}; - entry_functions.push_back(f); + Array entry_functions{tvm::Expr{"main"}}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize());