diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index b401dad33310..9e46677366bc 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -110,7 +110,15 @@ class CalcDep : private ExprVisitor { VarMap use_map_; void VisitExpr(const Expr& e) final { - return ExprFunctor::VisitExpr(e); + visit_counter_[e.get()]++; + // The dce code seprate variable into three parts: + // used 0 times (remove) + // used 1 times (inline) + // used 2 times (dont do anything). + if (visit_counter_[e.get()] <= 2) { + using TParent = ExprFunctor; + TParent::VisitExpr(e); + } } void VisitExpr_(const LetNode* l) final { diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index ae848f014448..89bae1f71b47 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -19,7 +19,9 @@ from tvm.relay import Function, transform from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, assert_alpha_equal from tvm.relay.op import log, add, equal, subtract +from tvm.relay.testing import inception_v3 +import pytest class env: def __init__(self): @@ -129,6 +131,12 @@ def test_tuple_get_item(): assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) +@pytest.mark.timeout(timeout=10, method="thread") +def test_complexity(): + g = inception_v3.get_net(1, 1000, (3, 299, 299), 'float32') + run_opt_pass(g, transform.DeadCodeElimination()) + + if __name__ == "__main__": test_let() test_used_let() @@ -138,3 +146,4 @@ def test_tuple_get_item(): test_recursion_dead() test_op_let() test_tuple_get_item() + test_complexity()