From 03b7cdb53dd205b622ad435cc21479e42cbfbc23 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 28 Feb 2019 14:36:49 -0800 Subject: [PATCH 1/2] do --- src/relay/pass/well_formed.cc | 11 +++++++++- tests/python/relay/test_ir_well_formed.py | 25 ++++++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index d9c6b617ca5f..159e073673da 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include namespace tvm { @@ -12,7 +13,7 @@ namespace relay { //! brief make sure each Var is bind at most once. -class WellFormedChecker : private ExprVisitor { +class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; std::unordered_set s; @@ -39,6 +40,14 @@ class WellFormedChecker : private ExprVisitor { CheckWellFormed(f->body); } + void VisitPattern(const Pattern& p) final { + PatternVisitor::VisitPattern(p); + } + + void VisitVar(const Var& v) final { + Check(v); + } + public: bool CheckWellFormed(const Expr& e) { this->VisitExpr(e); diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index 725b2fbd3c3d..7a1eb54dfd1f 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -1,9 +1,10 @@ import tvm from tvm import relay from tvm.relay.ir_pass import well_formed +from tvm.relay.prelude import Prelude -def test_well_formed(): - x = relay.Var('x') +def test_let(): + x = relay.Var("x") assert well_formed(x) v = relay.Constant(tvm.nd.array(10)) ty = None @@ -18,7 +19,7 @@ def test_well_formed(): def test_tuple(): - x = relay.Var('x') + x = relay.Var("x") assert well_formed(x) v = relay.Constant(tvm.nd.array(10)) let = relay.Let(x, v, x) @@ -30,3 +31,21 @@ def test_tuple(): def test_tuple_get_item(): t = relay.Var('t') assert well_formed(relay.TupleGetItem(t, 2)) + + +def test_adt(): + mod = relay.Module() + p = Prelude(mod) + x = relay.Var("x") + s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x) + default_case = relay.Clause(relay.PatternVar(x), x) + m0 = relay.Match(p.z(), [default_case]) + m1 = relay.Match(p.z(), [s_case, default_case]) + assert well_formed(m0) + assert not well_formed(m1) + +if __name__ == "__main__": + test_let() + test_tuple() + test_tuple_get_item() + test_adt() From 9579e65e20b221a58c1209914caf6d67b75a498c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 28 Feb 2019 21:28:41 -0800 Subject: [PATCH 2/2] address comment --- tests/python/relay/test_ir_well_formed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index 7a1eb54dfd1f..b9e907144785 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -29,7 +29,7 @@ def test_tuple(): def test_tuple_get_item(): - t = relay.Var('t') + t = relay.Var("t") assert well_formed(relay.TupleGetItem(t, 2))