Skip to content

Commit

Permalink
[Relay] Fix Partial Evaluator, Add stricter checking for CheckWellFor…
Browse files Browse the repository at this point in the history
…med (apache#3749)

* aot

* save

* save

* fix test

* remove vta changes

* lint
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 6, 2019
1 parent beecbee commit c87ec1a
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 62 deletions.
7 changes: 5 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
}

Clause ExprMutator::VisitClause(const Clause& c) {
return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs));
Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Expand Down Expand Up @@ -395,7 +396,9 @@ class ExprBinder : public ExprMutator, PatternMutator {
}

Var VisitVar(const Var& v) final {
return Downcast<Var>(VisitExpr(v));
CHECK(!args_map_.count(v))
<< "Cannnot bind an internal pattern variable";
return v;
}

private:
Expand Down
14 changes: 6 additions & 8 deletions src/relay/pass/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Expr DeDup(const Expr& e) {
}

Var Fresh(const Var& v) {
CHECK_EQ(rename_.count(v), 0);
CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
Expand Down Expand Up @@ -84,18 +86,13 @@ Expr DeDup(const Expr& e) {
}

Pattern VisitPattern(const Pattern& p) final {
return PatternMutator::VisitPattern(p);
return PatternFunctor::VisitPattern(p);
}

Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var));
}

Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
}

Type VisitType_(const TypeVarNode* op) final {
TypeVar v = GetRef<TypeVar>(op);
return type_rename_.count(v) != 0 ? type_rename_.at(v) : v;
Expand All @@ -109,9 +106,10 @@ Expr DeDup(const Expr& e) {
std::unordered_map<Var, Var, NodeHash, NodeEqual> rename_;
std::unordered_map<TypeVar, TypeVar, NodeHash, NodeEqual> type_rename_;
};

CHECK(WellFormed(e)) << AsText(e, false);
Expr ret = DeDupMutator().VisitExpr(e);
CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size());
CHECK(WellFormed(ret));
CHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
return ret;
}

Expand Down
2 changes: 2 additions & 0 deletions src/relay/pass/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define TVM_RELAY_PASS_LET_LIST_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <utility>
#include <vector>
#include <tuple>
Expand Down Expand Up @@ -63,6 +64,7 @@ class LetList {
*/
Var Push(Var pv, Expr expr) {
CHECK(!used_);
CHECK(WellFormed(expr));
lets_.emplace_back(std::make_pair(pv, expr));
return pv;
}
Expand Down
90 changes: 49 additions & 41 deletions src/relay/pass/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class Environment {

void Insert(const Var& v, const PStatic& ps) {
CHECK(ps.defined());
CHECK_GT(env_.size(), 0);
CHECK_EQ(env_.back().locals.count(v), 0);
env_.back().locals[v] = ps;
}
Expand Down Expand Up @@ -604,10 +605,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}

PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (auto* op = e.as<CallNode>()) {
if (op->op.same_as(WithFuncIdOp())) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll, name);
if (const CallNode* c = e.as<CallNode>()) {
if (c->op.same_as(WithFuncIdOp())) {
CHECK_EQ(c->args.size(), 1);
return VisitExpr(c->args[0], ll, name);
}
}
PStatic ret = e.as<FunctionNode>() ?
Expand Down Expand Up @@ -801,34 +802,36 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
LetList* ll) {
return env_.Extend<PStatic>([&]() {
CHECK_EQ(pv.size(), func->params.size());
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
CHECK_GT(func_map_.count(func), 0);
FuncId fid = func_map_.at(func);
if (fuel_map_.count(fid) == 0) {
fuel_map_.insert({fid, MkFTop()});
}
std::vector<Fuel> args_fuel;
for (const auto& v : pv) {
args_fuel.push_back(GetFuel(v));
}
auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
if (std::get<1>(meet_res)) {
FuelFrame tf(this, fid, std::get<0>(meet_res));
Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
Function func = AsFunc(dedup_func);
if (var.as<VarNode>()) {
env_.Insert(Downcast<Var>(var), self);
}
for (size_t i = 0; i < pv.size(); ++i) {
env_.Insert(func->params[i], pv[i]);
}
for (const auto& p : free_vars) {
env_.Insert(p.first, p.second);
}
tvm::Map<TypeVar, Type> subst;
for (size_t i = 0; i < type_args.size(); ++i) {
subst.Set(func->type_params[i], type_args[i]);
}
for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
subst.Set(func->type_params[i], IncompleteTypeNode::make(kType));
}
return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
} else {
std::vector<Expr> dyn;
Expand Down Expand Up @@ -979,32 +982,37 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
PStatic ps = VisitExpr(op->data, ll);
return env_.Extend<PStatic>([&]() {
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match:
return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch:
continue;
case MatchStatus::Unknown:
for (const Clause& c : op->clauses) {
switch (VisitPattern(c->lhs, ps)) {
case MatchStatus::Match:
return VisitExpr(c->rhs, ll);
case MatchStatus::NoMatch:
continue;
case MatchStatus::Unknown:
return [&]() {
tvm::Array<Clause> clauses;
for (const Clause& c : op->clauses) {
Expr expr = store_.Extend<Expr>([&]() {
return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v));
}
return VisitExpr(c->rhs, ll)->dynamic;
});
return LetList::With([&](LetList* ll) {
for (const Var& v : BoundVars(c->lhs)) {
env_.Insert(v, NoStatic(v));
}
return VisitExpr(c->rhs, ll)->dynamic;
});
});
clauses.push_back(ClauseNode::make(c->lhs, expr));
}
store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
}
}();
default:
LOG(FATAL) << "Unknown MatchStatus";
throw;
}
LOG(FATAL) << "No case Match";
throw;
});
}
LOG(FATAL) << "No case Match";
throw;
});
}

MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
Expand Down
6 changes: 5 additions & 1 deletion src/relay/pass/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,11 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
private:
const tvm::Map<TypeVar, Type>& subst_map_;
};
return TypeSubstMutator(subst_map).VisitExpr(expr);
CHECK(WellFormed(expr));
auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
CHECK(WellFormed(ret));
return ret;
}

} // namespace relay
Expand Down
62 changes: 55 additions & 7 deletions src/relay/pass/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,84 @@ namespace relay {
class WellFormedChecker : private ExprVisitor, PatternVisitor {
bool well_formed = true;

std::unordered_set<Var, NodeHash, NodeEqual> s;
std::vector<std::unordered_set<Var, NodeHash, NodeEqual>> scope;
std::unordered_set<Var, NodeHash, NodeEqual> current_bound;
std::unordered_set<Var, NodeHash, NodeEqual> total_bound;
std::unordered_set<Var, NodeHash, NodeEqual> free;

void Check(const Var& v) {
if (s.count(v) != 0) {
struct Scope {
WellFormedChecker* wfc;
explicit Scope(WellFormedChecker* wfc) : wfc(wfc) {
wfc->scope.push_back({});
}
~Scope() {
CHECK_GE(wfc->scope.size(), 0);
for (const Var& v : wfc->scope.back()) {
CHECK_GE(wfc->current_bound.count(v), 0);
wfc->current_bound.erase(v);
}
wfc->scope.pop_back();
}
};

void Bound(const Var& v) {
if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
well_formed = false;
}
s.insert(v);
CHECK_GE(scope.size(), 0);
scope.back().insert(v);
current_bound.insert(v);
total_bound.insert(v);
}

void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
if (total_bound.count(v) != 0) {
well_formed = false;
} else {
free.insert(v);
}
}
}

void VisitExpr_(const LetNode* l) final {
Scope s(this);
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check(l->var);
Bound(l->var);
CheckWellFormed(l->value);
CheckWellFormed(l->body);
}

void VisitExpr_(const FunctionNode* f) final {
Scope s(this);
for (const Var& param : f->params) {
Check(param);
Bound(param);
}
CheckWellFormed(f->body);
}

void VisitClause(const Clause& c) final {
Scope s(this);
VisitPattern(c->lhs);
VisitExpr(c->rhs);
}

void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}

void VisitVar(const Var& v) final {
Check(v);
Bound(v);
}

void VisitExpr(const Expr& e) final {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
ExprVisitor::VisitExpr(e);
}
}

public:
Expand Down
15 changes: 12 additions & 3 deletions tests/python/relay/test_error_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,36 @@ def check_type_err(expr, msg):
except tvm.TVMError as err:
assert msg in str(err)

def test_wellformed():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
check_type_err(
f(x),
"Check failed: WellFormed")

def test_too_many_args():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
y = relay.var('y', shape=(10, 10))
check_type_err(
f(x, y),
f(y, y),
"the function is provided too many arguments expected 1, found 2;")

def test_too_few_args():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
z = relay.var('z', shape=(10, 10))
f = relay.Function([x, y], x)
check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")
check_type_err(f(z), "the function is provided too few arguments expected 2, found 1;")

def test_rel_fail():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(11, 10))
f = relay.Function([x, y], x + y)
check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
check_type_err(f, "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")

if __name__ == "__main__":
test_wellformed()
test_too_many_args()
test_too_few_args()
test_rel_fail()
9 changes: 9 additions & 0 deletions tests/python/relay/test_pass_partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,16 @@ def test_triangle_number():
assert_alpha_equal(dcpe(orig), const(55))


def test_nat_update():
m = Module()
p = Prelude(m)
add_nat_definitions(p)
m = transform.ToANormalForm()(m)
transform.PartialEvaluate()(m)


if __name__ == '__main__':
test_nat_update()
test_ref()
test_tuple()
test_empty_ad()
Expand Down

0 comments on commit c87ec1a

Please sign in to comment.