From e08a1f1aadaf5eb916929a7ed3009ae3131eb955 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Sun, 6 Oct 2019 12:06:05 +0800 Subject: [PATCH] dicrease the complexity of CalcDep from exponential to linear --- src/relay/pass/dead_code.cc | 6 +++++- tests/python/relay/test_pass_dead_code_elimination.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index b401dad333106..470bdf6d44615 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -110,7 +110,11 @@ class CalcDep : private ExprVisitor { VarMap use_map_; void VisitExpr(const Expr& e) final { - return ExprFunctor::VisitExpr(e); + visit_counter_[e.get()]++; + if (visit_counter_[e.get()] <= 2) { + using TParent = ExprFunctor; + TParent::VisitExpr(e); + } } void VisitExpr_(const LetNode* l) final { diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index ae848f0144487..89bae1f71b47c 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -19,7 +19,9 @@ from tvm.relay import Function, transform from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal from tvm.relay.op import log, add, equal, subtract +from tvm.relay.testing import inception_v3 +import pytest class env: def __init__(self): @@ -129,6 +131,12 @@ def test_tuple_get_item(): assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) +@pytest.mark.timeout(timeout=10, method="thread") +def test_complexity(): + g = inception_v3.get_net(1, 1000, (3, 299, 299), 'float32') + run_opt_pass(g, transform.DeadCodeElimination()) + + if __name__ == "__main__": test_let() test_used_let() @@ -138,3 +146,4 @@ def test_tuple_get_item(): test_recursion_dead() test_op_let() test_tuple_get_item() + test_complexity()