Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix Partial Evaluator, Add stricter checking for CheckWellFormed #3749

Merged
merged 6 commits into from
Aug 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
}

Clause ExprMutator::VisitClause(const Clause& c) {
return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs));
Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Expand Down Expand Up @@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator {
}

Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
CHECK(!args_map_.count(v))
<< "Cannnot bind an internal pattern variable";
return v;
}

private:
Expand Down
14 changes: 6 additions & 8 deletions src/relay/pass/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) {
}

Var Fresh(const Var& v) {
CHECK_EQ(rename_.count(v), 0);
CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
Expand Down Expand Up @@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) {
}

Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
return PatternFunctor::VisitPattern(p);
}

Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var));
}

Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}

Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
Expand All @@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) {
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};

CHECK(WellFormed(e)) << AsText(e, false);
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
CHECK(WellFormed(ret));
CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
return ret;
}

Expand Down
2 changes: 2 additions & 0 deletions src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define TVM_RELAY_PASS_LET_LIST_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <utility>
#include <vector>
#include <tuple>
Expand Down Expand Up @@ -63,6 +64,7 @@ class LetList {
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
CHECK(WellFormed(expr));
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down
90 changes: 49 additions & 41 deletions src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class Environment {

void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined());
CHECK_GT(env_.size(), 0);
CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps;
}
Expand Down Expand Up @@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}

PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (auto* op = e.as<CallNode>()) {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll, name);
if (const CallNode* c = e.as<CallNode>()) {
if (c->op.same_as(WithFuncIdOp())) {
CHECK_EQ(c->args.size(), 1);
return VisitExpr(c->args[0], ll, name);
}
}
PStatic ret = e.as<FunctionNode>() ?
Expand Down Expand Up @@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func);
if (fuel_map_.count(fid) == 0) {
fuel_map_.insert({fid, MkFTop()});
}
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
if (std::get<1>(meet_res)) {
FuelFrame tf(this, fid, std::get<0>(meet_res));
Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
Function func = AsFunc(dedup_func);
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
} else {
std::vector<Expr> dyn;
Expand Down Expand Up @@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->data, ll);
return env_.Extend<PStatic>([&]() {
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match:
return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch:
continue;
case MatchStatus::Unknown:
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match:
return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch:
continue;
case MatchStatus::Unknown:
return [&]() {
tvm::Array<Clause> clauses;
for (const Clause& c : op->clauses) {
Expr expr = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v));
}
return VisitExpr(c->rhs, ll)->dynamic;
});
return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v));
}
return VisitExpr(c->rhs, ll)->dynamic;
});
});
clauses.push_back(ClauseNode::make(c->lhs, expr));
}
store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
}
}();
default:
LOG(FATAL) << "Unknown MatchStatus";
throw;
}
LOG(FATAL) << "No case Match";
throw;
});
}
LOG(FATAL) << "No case Match";
throw;
});
}

MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
Expand Down
6 changes: 5 additions & 1 deletion src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
return TypeSubstMutator(subst_map).VisitExpr(expr);
CHECK(WellFormed(expr));
auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
CHECK(WellFormed(ret));
return ret;
}

} // namespace relay
Expand Down
62 changes: 55 additions & 7 deletions src/relay/pass/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,84 @@ namespace relay {
class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true;

std::unordered_set<Var, NodeHash, NodeEqual> s;
std::vector<std::unordered_set<Var, NodeHash, NodeEqual>> scope;
std::unordered_set<Var, NodeHash, NodeEqual> current_bound;
std::unordered_set<Var, NodeHash, NodeEqual> total_bound;
std::unordered_set<Var, NodeHash, NodeEqual> free;

void Check(const Var& v) {
if (s.count(v) != 0) {
struct Scope {
WellFormedChecker* wfc;
explicit Scope(WellFormedChecker* wfc) : wfc(wfc) {
wfc->scope.push_back({});
}
~Scope() {
CHECK_GE(wfc->scope.size(), 0);
for (const Var& v : wfc->scope.back()) {
CHECK_GE(wfc->current_bound.count(v), 0);
wfc->current_bound.erase(v);
}
wfc->scope.pop_back();
}
};

void Bound(const Var& v) {
if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
well_formed = false;
}
s.insert(v);
CHECK_GE(scope.size(), 0);
scope.back().insert(v);
current_bound.insert(v);
total_bound.insert(v);
}

void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
if (total_bound.count(v) != 0) {
well_formed = false;
} else {
free.insert(v);
}
}
}

void VisitExpr_(const LetNode* l) final {
Scope s(this);
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
Bound(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}

void VisitExpr_(const FunctionNode* f) final {
Scope s(this);
for (const Var& param : f->params) {
Check(param);
Bound(param);
}
CheckWellFormed(f->body);
}

void VisitClause(const Clause& c) final {
Scope s(this);
VisitPattern(c->lhs);
VisitExpr(c->rhs);
}

void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}

void VisitVar(const Var& v) final {
Check(v);
Bound(v);
}

void VisitExpr(const Expr& e) final {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
ExprVisitor::VisitExpr(e);
}
}

public:
Expand Down
15 changes: 12 additions & 3 deletions tests/python/relay/test_error_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,36 @@ def check_type_err(expr, msg):
except tvm.TVMError as err:
assert msg in str(err)

def test_wellformed():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
check_type_err(
f(x),
"Check failed: WellFormed")

def test_too_many_args():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
y = relay.var('y', shape=(10, 10))
check_type_err(
f(x, y),
f(y, y),
"the function is provided too many arguments expected 1, found 2;")

def test_too_few_args():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
z = relay.var('z', shape=(10, 10))
f = relay.Function([x, y], x)
check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")
check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;")

def test_rel_fail():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(11, 10))
f = relay.Function([x, y], x + y)
check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")

if __name__ == "__main__":
test_wellformed()
test_too_many_args()
test_too_few_args()
test_rel_fail()
9 changes: 9 additions & 0 deletions tests/python/relay/test_pass_partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,16 @@ def test_triangle_number():
assert_alpha_equal(dcpe(orig), const(55))


def test_nat_update():
m = Module()
p = Prelude(m)
add_nat_definitions(p)
m = transform.ToANormalForm()(m)
transform.PartialEvaluate()(m)


if __name__ == '__main__':
test_nat_update()
test_ref()
test_tuple()
test_empty_ad()
Expand Down