Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] fix checkwellform #2705

Merged
merged 2 commits into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/relay/pass/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <unordered_set>

namespace tvm {
namespace relay {


//! brief make sure each Var is bind at most once.
class WellFormedChecker : private ExprVisitor {
class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true;

std::unordered_set<Var, NodeHash, NodeEqual> s;
Expand All @@ -39,6 +40,14 @@ class WellFormedChecker : private ExprVisitor {
CheckWellFormed(f->body);
}

void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}

void VisitVar(const Var& v) final {
Check(v);
}

public:
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
Expand Down
27 changes: 23 additions & 4 deletions tests/python/relay/test_ir_well_formed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tvm
from tvm import relay
from tvm.relay.ir_pass import well_formed
from tvm.relay.prelude import Prelude

def test_well_formed():
x = relay.Var('x')
def test_let():
x = relay.Var("x")
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
Expand All @@ -18,7 +19,7 @@ def test_well_formed():


def test_tuple():
x = relay.Var('x')
x = relay.Var("x")
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
let = relay.Let(x, v, x)
Expand All @@ -28,5 +29,23 @@ def test_tuple():


def test_tuple_get_item():
t = relay.Var('t')
t = relay.Var("t")
assert well_formed(relay.TupleGetItem(t, 2))


def test_adt():
mod = relay.Module()
p = Prelude(mod)
x = relay.Var("x")
s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
default_case = relay.Clause(relay.PatternVar(x), x)
m0 = relay.Match(p.z(), [default_case])
m1 = relay.Match(p.z(), [s_case, default_case])
assert well_formed(m0)
assert not well_formed(m1)

if __name__ == "__main__":
test_let()
test_tuple()
test_tuple_get_item()
test_adt()