Skip to content

Commit

Permalink
[Relay] GNF (apache#2492)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and wweic committed Mar 9, 2019
1 parent fcda217 commit 4a14e2a
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 26 deletions.
13 changes: 12 additions & 1 deletion include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,18 @@ struct StructuralHash {
*
* \return expression in A-Normal Form
*/
Expr ToANF(const Expr& e, const Module& mod);
Expr ToANormalForm(const Expr& e, const Module& mod);

/*! \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \param e the expression.
*
* \return the expression in graph normal form.
*/
Expr ToGraphNormalForm(const Expr& e);

} // namespace relay
} // namespace tvm
Expand Down
19 changes: 17 additions & 2 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)


def to_anf(expr, mod=None):
def to_a_normal_form(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
Expand All @@ -513,7 +513,21 @@ def to_anf(expr, mod=None):
expr: tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_anf(expr, mod)
return _ir_pass.to_a_normal_form(expr, mod)


def to_graph_normal_form(expr):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
expr : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)


def gradient(expr, mod=None):
Expand All @@ -534,6 +548,7 @@ def gradient(expr, mod=None):
"""
return _ir_pass.first_order_gradient(expr, mod)


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down
32 changes: 16 additions & 16 deletions src/relay/pass/to_anf.cc → src/relay/pass/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body)
return Creator(arena).Create(body);
}

Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
Expand Down Expand Up @@ -258,11 +258,11 @@ bool IsPrimitiveFunction(const Expr& e) {

class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANF(const Expr& e,
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
static Expr ToANormalForm(const Expr& e,
const Module& m,
const DependencyGraph& dg,
std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
std::set<GlobalVar>* gv) {
Fill fi(m, dg, node_scope, gv);
return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
}
Expand Down Expand Up @@ -396,7 +396,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
GlobalVar gv = GetRef<GlobalVar>(gvn);
if (visited_->count(gv) == 0) {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANF(mod_->Lookup(gv), mod_, visited_)));
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return gv;
}
Expand All @@ -423,7 +423,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
};

Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
Expand All @@ -446,29 +446,29 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
* We do an additional pass to fill all the LetList and we are done.
*/
std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
return Fill::ToANF(e, m, dg, &node_scope, gv);
return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
}

Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
if (const auto* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params,
ToANFAux(f->body, m, gv),
ToANormalFormAux(f->body, m, gv),
f->ret_type,
f->type_params,
f->attrs);
} else {
return ToANFAux(e, m, gv);
return ToANormalFormAux(e, m, gv);
}
}

Expr ToANF(const Expr& e, const Module& m) {
Expr ToANormalForm(const Expr& e, const Module& m) {
std::set<GlobalVar> gv;
return ToANF(e, m, &gv);
return ToANormalForm(e, m, &gv);
}

TVM_REGISTER_API("relay._ir_pass.to_anf")
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ToANF(args[0], args[1]);
*ret = ToANormalForm(args[0], args[1]);
});

} // namespace relay
Expand Down
66 changes: 66 additions & 0 deletions src/relay/pass/to_graph_normal_form.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*!
* Copyright (c) 2018 by Contributors
*
* \file to_gnf.cc
*
* \brief Turn A normal form into graph normal form.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include "let_list.h"

namespace tvm {
namespace relay {

class UseVarVisitor : public ExprVisitor {
public:
explicit UseVarVisitor(const Var& v) : v(v) { }

static bool UseVar(const Var& v, const Expr& e) {
UseVarVisitor uv(v);
uv(e);
return uv.use_var;
}

private:
bool use_var = false;
Var v;

void VisitExpr_(const VarNode* vn) override {
use_var = use_var || (v == GetRef<Var>(vn));
}
};

class GNF : public ExprMutator {
private:
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map_;
Expr VisitExpr_(const VarNode* vn) override {
Var v = GetRef<Var>(vn);
return var_map_.count(v) == 0 ? v : var_map_.at(v);
}

static bool UseVar(const Var& v, const Expr& e) {
return UseVarVisitor::UseVar(v, e);
}

static Expr WrapRec(const Var& var, const Expr& val) {
return UseVar(var, val) ? LetNode::make(var, val, var) : val;
}

Expr VisitExpr_(const LetNode* ln) override {
var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value))));
return VisitExpr(ln->body);
}
};

Expr ToGraphNormalForm(const Expr& e) {
return GNF()(e);
}

TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ToGraphNormalForm(args[0]);
});

} // namespace relay
} // namespace tvm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
Expand All @@ -21,7 +21,7 @@ def test_explicit_bound():
z = op.add(y, y)
f = relay.Function([], op.add(z, z))
assert not "let" in f.astext() # assert the values are implicitly bounded
anf = to_anf(f)
anf = to_a_normal_form(f)
assert "let" in anf.astext() # assert the values are explicitly bounded
check_eval(f(), 8.0)
check_eval(anf(), 8.0)
Expand All @@ -35,7 +35,7 @@ def test_order():
x = relay.const(1)
val = x + y * z
check_eval(val, 7.0)
anf = infer_type(to_anf(val))
anf = infer_type(to_a_normal_form(val))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
Expand All @@ -54,7 +54,7 @@ def test_order():
def test_if():
cond = relay.const(True)
x = relay.If(cond, relay.const(2), relay.const(3))
anf = infer_type(to_anf(x))
anf = infer_type(to_a_normal_form(x))
a = relay.Var('a', relay.IncompleteType())
b = relay.Var('b', relay.IncompleteType())
c = relay.Var('c', relay.IncompleteType())
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_recursion():
mod[f] = value
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
old_f = mod[f]
f = to_anf(f, mod=mod)
f = to_a_normal_form(f, mod=mod)
check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)


Expand All @@ -111,7 +111,7 @@ def test_ref():
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
check_eval(body, 3)
check_eval(to_anf(body), 3)
check_eval(to_a_normal_form(body), 3)


# this is an example of using the adt value in python side
Expand All @@ -135,7 +135,7 @@ def test_add():
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
assert count(intrp.evaluate(add(s(z()), s(z())))) == 2
assert count(intrp.evaluate(to_anf(add(s(z()), s(z())), mod))) == 2
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()

if __name__ == '__main__':
Expand Down
51 changes: 51 additions & 0 deletions tests/python/relay/test_to_graph_normal_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue


def check_eval(expr, args, expected_result, mod=None, rtol=1e-07):
if mod is None:
mod = relay.Module()

ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")

result = intrp.evaluate(expr)(*args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)


def test_implicit_share():
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f)
assert "let" in f.astext()
assert not "let" in g.astext()
check_eval(f, [], 8.0)
check_eval(g, [], 8.0)


def test_round_trip():
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
body = relay.Let(z, op.add(y, y), op.add(z, z))
body = relay.Let(y, op.add(x, x), body)
f = relay.Function([], relay.Let(x, relay.const(1), body))
g = to_graph_normal_form(f)
h = to_a_normal_form(g)
assert "let" in f.astext()
assert not "let" in g.astext()
check_eval(f, [], 8.0)
check_eval(g, [], 8.0)
check_eval(h, [], 8.0)

if __name__ == '__main__':
test_implicit_share()
test_round_trip()

0 comments on commit 4a14e2a

Please sign in to comment.