From 7a492943bfe300862ed535424d9762e444609ffa Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Mon, 18 May 2020 10:12:49 -0700 Subject: [PATCH] fix pattern topological order (#5612) --- src/relay/ir/indexed_graph.cc | 10 ++++--- src/relay/ir/indexed_graph.h | 4 ++- tests/python/relay/test_dataflow_pattern.py | 30 +++++++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 79ec57426d66..7f7a5ff66853 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -191,10 +191,12 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { protected: void VisitDFPattern(const DFPattern& pattern) override { - DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); + if (this->visited_.count(pattern.get()) == 0) { + DFPatternVisitor::VisitDFPattern(pattern); + auto node = std::make_shared::Node>(pattern, index_++); + graph_.node_map_[pattern] = node; + graph_.topological_order_.push_back(node); + } } IndexedGraph graph_; size_t index_ = 0; diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d2524340f971..022eb3bd5491 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -69,7 +69,7 @@ class IndexedGraph { std::vector outputs_; /*! \brief The depth of the node in the dominator tree */ - size_t depth_; + size_t depth_ = 0; /*! \brief The dominator parent/final user of the outputs of this node */ Node* dominator_parent_; /*! \brief The nodes this node dominates */ @@ -115,6 +115,8 @@ class IndexedGraph { return nullptr; } while (lhs != rhs) { + CHECK(lhs); + CHECK(rhs); if (lhs->depth_ < rhs->depth_) { rhs = rhs->dominator_parent_; } else if (lhs->depth_ > rhs->depth_) { diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index a93a39be14d0..41b3d6d997e9 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -425,6 +425,35 @@ def callback(self, pre, post, node_map): out = rewrite(TestRewrite(), x + y) assert sub_pattern.match(out) +def test_nested_rewrite(): + class PatternCallback(DFPatternCallback): + def __init__(self, pattern): + self.pattern = pattern + + def callback(self, pre, post, node_map): + return post + + def gen(): + x = relay.var('x') + y = relay.var('y') + y_add = relay.add(y, y) + n0 = relay.add(x, y_add) + n1 = relay.add(x, n0) + return relay.add(n1, n0) + + def pattern(): + a = wildcard() + b = wildcard() + n0 = is_op('add')(a, b) + n1 = is_op('add')(n0, a) + return is_op('add')(n0, n1) + + out = gen() + pat = pattern() + new_out = rewrite(PatternCallback(pat), out) + + assert tvm.ir.structural_equal(out, new_out) + def test_not_fuse_multi_diamond(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) @@ -838,6 +867,7 @@ def test_parition_double_batchnorm(): test_no_match_diamond() test_match_fake_diamond() test_rewrite() + test_nested_rewrite() test_fuse_batchnorm() test_no_fuse_batchnorm() test_fuse_double_batchnorm()