diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 1a1fc367ecd0..da9f7b8d19b9 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { } Clause ExprMutator::VisitClause(const Clause& c) { - return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs)); + Pattern p = VisitPattern(c->lhs); + return ClauseNode::make(p, VisitExpr(c->rhs)); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator { } Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); + CHECK(!args_map_.count(v)) + << "Cannnot bind an internal pattern variable"; + return v; } private: diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index d5d4f6960653..332803cb71ba 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) { } Var Fresh(const Var& v) { + CHECK_EQ(rename_.count(v), 0); + CHECK_EQ(memo_.count(v), 0) << v.as(); Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); rename_[v] = ret; return ret; @@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) { } Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); + return PatternFunctor::VisitPattern(p); } Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVarNode::make(Fresh(op->var)); } - Clause VisitClause(const Clause& c) final { - Pattern pat = VisitPattern(c->lhs); - return ClauseNode::make(pat, VisitExpr(c->rhs)); - } - Type VisitType_(const TypeVarNode* op) final { TypeVar v = GetRef(op); return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; @@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) { std::unordered_map rename_; std::unordered_map type_rename_; }; - + CHECK(WellFormed(e)) << AsText(e, false); Expr ret = DeDupMutator().VisitExpr(e); - CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); + CHECK(WellFormed(ret)); + CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size()); return ret; } diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index e90ab12b10bf..94b5ea3ad42a 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -31,6 +31,7 @@ #define TVM_RELAY_PASS_LET_LIST_H_ #include +#include #include #include #include @@ -63,6 +64,7 @@ class LetList { */ Var Push(Var pv, Expr expr) { CHECK(!used_); + CHECK(WellFormed(expr)); lets_.emplace_back(std::make_pair(pv, expr)); return pv; } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 1ea63e84e7bc..3f92d7af0eda 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -396,6 +396,7 @@ class Environment { void Insert(const Var& v, const PStatic& ps) { CHECK(ps.defined()); + CHECK_GT(env_.size(), 0); CHECK_EQ(env_.back().locals.count(v), 0); env_.back().locals[v] = ps; } @@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) { - if (auto* op = e.as()) { - if (op->op.same_as(WithFuncIdOp())) { - CHECK_EQ(op->args.size(), 1); - return VisitExpr(op->args[0], ll, name); + if (const CallNode* c = e.as()) { + if (c->op.same_as(WithFuncIdOp())) { + CHECK_EQ(c->args.size(), 1); + return VisitExpr(c->args[0], ll, name); } } PStatic ret = e.as() ? @@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor LetList* ll) { return env_.Extend([&]() { CHECK_EQ(pv.size(), func->params.size()); - if (var.as()) { - env_.Insert(Downcast(var), self); - } - for (size_t i = 0; i < pv.size(); ++i) { - env_.Insert(func->params[i], pv[i]); - } - for (const auto& p : free_vars) { - env_.Insert(p.first, p.second); - } - tvm::Map subst; - for (size_t i = 0; i < type_args.size(); ++i) { - subst.Set(func->type_params[i], type_args[i]); - } - for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); - } - std::vector args_fuel; - for (const auto& v : pv) { - args_fuel.push_back(GetFuel(v)); - } CHECK_GT(func_map_.count(func), 0); FuncId fid = func_map_.at(func); if (fuel_map_.count(fid) == 0) { fuel_map_.insert({fid, MkFTop()}); } + std::vector args_fuel; + for (const auto& v : pv) { + args_fuel.push_back(GetFuel(v)); + } auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); if (std::get<1>(meet_res)) { FuelFrame tf(this, fid, std::get<0>(meet_res)); + Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); + Function func = AsFunc(dedup_func); + if (var.as()) { + env_.Insert(Downcast(var), self); + } + for (size_t i = 0; i < pv.size(); ++i) { + env_.Insert(func->params[i], pv[i]); + } + for (const auto& p : free_vars) { + env_.Insert(p.first, p.second); + } + tvm::Map subst; + for (size_t i = 0; i < type_args.size(); ++i) { + subst.Set(func->type_params[i], type_args[i]); + } + for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { + subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); + } return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); } else { std::vector dyn; @@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor PStatic VisitExpr_(const MatchNode* op, LetList* ll) final { PStatic ps = VisitExpr(op->data, ll); return env_.Extend([&]() { - for (const Clause& c : op->clauses) { - switch (VisitPattern(c->lhs, ps)) { - case MatchStatus::Match: - return VisitExpr(c->rhs, ll); - case MatchStatus::NoMatch: - continue; - case MatchStatus::Unknown: + for (const Clause& c : op->clauses) { + switch (VisitPattern(c->lhs, ps)) { + case MatchStatus::Match: + return VisitExpr(c->rhs, ll); + case MatchStatus::NoMatch: + continue; + case MatchStatus::Unknown: + return [&]() { tvm::Array clauses; for (const Clause& c : op->clauses) { Expr expr = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - for (const Var& v : BoundVars(c->lhs)) { - env_.Insert(v, NoStatic(v)); - } - return VisitExpr(c->rhs, ll)->dynamic; - }); + return LetList::With([&](LetList* ll) { + for (const Var& v : BoundVars(c->lhs)) { + env_.Insert(v, NoStatic(v)); + } + return VisitExpr(c->rhs, ll)->dynamic; }); + }); clauses.push_back(ClauseNode::make(c->lhs, expr)); } store_.Invalidate(); return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete))); - } + }(); + default: + LOG(FATAL) << "Unknown MatchStatus"; + throw; } - LOG(FATAL) << "No case Match"; - throw; - }); + } + LOG(FATAL) << "No case Match"; + throw; + }); } MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final { diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index e2b71570bd2f..90c3de857329 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { private: const tvm::Map& subst_map_; }; - return TypeSubstMutator(subst_map).VisitExpr(expr); + CHECK(WellFormed(expr)); + auto ret = TypeSubstMutator(subst_map).VisitExpr(expr); + CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); + CHECK(WellFormed(ret)); + return ret; } } // namespace relay diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index bfe8865ab52f..27b31deb4f96 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -35,36 +35,84 @@ namespace relay { class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; - std::unordered_set s; + std::vector> scope; + std::unordered_set current_bound; + std::unordered_set total_bound; + std::unordered_set free; - void Check(const Var& v) { - if (s.count(v) != 0) { + struct Scope { + WellFormedChecker* wfc; + explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { + wfc->scope.push_back({}); + } + ~Scope() { + CHECK_GE(wfc->scope.size(), 0); + for (const Var& v : wfc->scope.back()) { + CHECK_GE(wfc->current_bound.count(v), 0); + wfc->current_bound.erase(v); + } + wfc->scope.pop_back(); + } + }; + + void Bound(const Var& v) { + if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) { well_formed = false; } - s.insert(v); + CHECK_GE(scope.size(), 0); + scope.back().insert(v); + current_bound.insert(v); + total_bound.insert(v); + } + + void VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + if (current_bound.count(v) == 0) { + if (total_bound.count(v) != 0) { + well_formed = false; + } else { + free.insert(v); + } + } } void VisitExpr_(const LetNode* l) final { + Scope s(this); // we do letrec only for FunctionNode, // but shadowing let in let binding is likely programming error, and we should forbidden it. - Check(l->var); + Bound(l->var); CheckWellFormed(l->value); CheckWellFormed(l->body); } void VisitExpr_(const FunctionNode* f) final { + Scope s(this); for (const Var& param : f->params) { - Check(param); + Bound(param); } CheckWellFormed(f->body); } + void VisitClause(const Clause& c) final { + Scope s(this); + VisitPattern(c->lhs); + VisitExpr(c->rhs); + } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } void VisitVar(const Var& v) final { - Check(v); + Bound(v); + } + + void VisitExpr(const Expr& e) final { + if (auto v = e.as()) { + VisitExpr_(v); + } else { + ExprVisitor::VisitExpr(e); + } } public: diff --git a/tests/python/relay/test_error_reporting.py b/tests/python/relay/test_error_reporting.py index c446f361101d..74e651884803 100644 --- a/tests/python/relay/test_error_reporting.py +++ b/tests/python/relay/test_error_reporting.py @@ -27,27 +27,36 @@ def check_type_err(expr, msg): except tvm.TVMError as err: assert msg in str(err) +def test_wellformed(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x) + check_type_err( + f(x), + "Check failed: WellFormed") + def test_too_many_args(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x) y = relay.var('y', shape=(10, 10)) check_type_err( - f(x, y), + f(y, y), "the function is provided too many arguments expected 1, found 2;") def test_too_few_args(): x = relay.var('x', shape=(10, 10)) y = relay.var('y', shape=(10, 10)) + z = relay.var('z', shape=(10, 10)) f = relay.Function([x, y], x) - check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;") + check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;") def test_rel_fail(): x = relay.var('x', shape=(10, 10)) y = relay.var('y', shape=(11, 10)) f = relay.Function([x, y], x + y) - check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);") + check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);") if __name__ == "__main__": + test_wellformed() test_too_many_args() test_too_few_args() test_rel_fail() diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 63493625b9de..f914f18b797e 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -323,7 +323,16 @@ def test_triangle_number(): assert_alpha_equal(dcpe(orig), const(55)) +def test_nat_update(): + m = Module() + p = Prelude(m) + add_nat_definitions(p) + m = transform.ToANormalForm()(m) + transform.PartialEvaluate()(m) + + if __name__ == '__main__': + test_nat_update() test_ref() test_tuple() test_empty_ad()