diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index b87f9319a3d3..75bfe92ec21c 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -320,7 +320,18 @@ struct StructuralHash { * * \return expression in A-Normal Form */ -Expr ToANF(const Expr& e, const Module& mod); +Expr ToANormalForm(const Expr& e, const Module& mod); + +/*! \brief Remove let binding and directly share via pointer instead. + * + * It will remove all let binding, + * and turn all of the variable bound by let into direct pointer reference. + * + * \param e the expression. + * + * \return the expression in graph normal form. + */ +Expr ToGraphNormalForm(const Expr& e); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 561c5d388788..02a6e8b5906e 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -490,7 +490,7 @@ def collect_device_annotation_ops(expr): return _ir_pass.CollectDeviceAnnotationOps(expr) -def to_anf(expr, mod=None): +def to_a_normal_form(expr, mod=None): """ Turn Graph Normal Form expression into A Normal Form Expression. @@ -513,7 +513,21 @@ def to_anf(expr, mod=None): expr: tvm.relay.Expr The output expression. """ - return _ir_pass.to_anf(expr, mod) + return _ir_pass.to_a_normal_form(expr, mod) + + +def to_graph_normal_form(expr): + """Turn A Normal Form expression into Graph Normal Form expression + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + Returns + ------- + expr : tvm.relay.Expr + The output expression + """ + return _ir_pass.to_graph_normal_form(expr) def gradient(expr, mod=None): @@ -534,6 +548,7 @@ def gradient(expr, mod=None): """ return _ir_pass.first_order_gradient(expr, mod) + def get_total_mac_number(expr): """ Count the number of MACs (multiply-accumulate) of a model diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_a_normal_form.cc similarity index 93% rename from src/relay/pass/to_anf.cc rename to src/relay/pass/to_a_normal_form.cc index 912774162b51..53e2c1c594f8 100644 --- a/src/relay/pass/to_anf.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -196,7 +196,7 @@ DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) return Creator(arena).Create(body); } -Expr ToANF(const Expr& e, const Module& m, std::set* gv); +Expr ToANormalForm(const Expr& e, const Module& m, std::set* gv); struct ScopeNode; using Scope = std::shared_ptr; @@ -258,11 +258,11 @@ bool IsPrimitiveFunction(const Expr& e) { class Fill : ExprFunctor { public: - static Expr ToANF(const Expr& e, - const Module& m, - const DependencyGraph& dg, - std::unordered_map* node_scope, - std::set* gv) { + static Expr ToANormalForm(const Expr& e, + const Module& m, + const DependencyGraph& dg, + std::unordered_map* node_scope, + std::set* gv) { Fill fi(m, dg, node_scope, gv); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); } @@ -396,7 +396,7 @@ class Fill : ExprFunctor { GlobalVar gv = GetRef(gvn); if (visited_->count(gv) == 0) { visited_->insert(gv); - mod_->Update(gv, Downcast(relay::ToANF(mod_->Lookup(gv), mod_, visited_))); + mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); } return gv; } @@ -423,7 +423,7 @@ class Fill : ExprFunctor { } }; -Expr ToANFAux(const Expr& e, const Module& m, std::set* gv) { +Expr ToANormalFormAux(const Expr& e, const Module& m, std::set* gv) { /* When you lift a lambda, what is inside is also being lift. * * So we must determine the scope of the lambda before determining the scope of it's body. @@ -446,29 +446,29 @@ Expr ToANFAux(const Expr& e, const Module& m, std::set* gv) { * We do an additional pass to fill all the LetList and we are done. */ std::unordered_map node_scope = CalcScope(dg); - return Fill::ToANF(e, m, dg, &node_scope, gv); + return Fill::ToANormalForm(e, m, dg, &node_scope, gv); } -Expr ToANF(const Expr& e, const Module& m, std::set* gv) { +Expr ToANormalForm(const Expr& e, const Module& m, std::set* gv) { if (const auto* f = e.as()) { return FunctionNode::make(f->params, - ToANFAux(f->body, m, gv), + ToANormalFormAux(f->body, m, gv), f->ret_type, f->type_params, f->attrs); } else { - return ToANFAux(e, m, gv); + return ToANormalFormAux(e, m, gv); } } -Expr ToANF(const Expr& e, const Module& m) { +Expr ToANormalForm(const Expr& e, const Module& m) { std::set gv; - return ToANF(e, m, &gv); + return ToANormalForm(e, m, &gv); } -TVM_REGISTER_API("relay._ir_pass.to_anf") +TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ToANF(args[0], args[1]); + *ret = ToANormalForm(args[0], args[1]); }); } // namespace relay diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc new file mode 100644 index 000000000000..bc1630263e3f --- /dev/null +++ b/src/relay/pass/to_graph_normal_form.cc @@ -0,0 +1,66 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file to_gnf.cc + * + * \brief Turn A normal form into graph normal form. + */ +#include +#include +#include "let_list.h" + +namespace tvm { +namespace relay { + +class UseVarVisitor : public ExprVisitor { + public: + explicit UseVarVisitor(const Var& v) : v(v) { } + + static bool UseVar(const Var& v, const Expr& e) { + UseVarVisitor uv(v); + uv(e); + return uv.use_var; + } + + private: + bool use_var = false; + Var v; + + void VisitExpr_(const VarNode* vn) override { + use_var = use_var || (v == GetRef(vn)); + } +}; + +class GNF : public ExprMutator { + private: + std::unordered_map var_map_; + Expr VisitExpr_(const VarNode* vn) override { + Var v = GetRef(vn); + return var_map_.count(v) == 0 ? v : var_map_.at(v); + } + + static bool UseVar(const Var& v, const Expr& e) { + return UseVarVisitor::UseVar(v, e); + } + + static Expr WrapRec(const Var& var, const Expr& val) { + return UseVar(var, val) ? LetNode::make(var, val, var) : val; + } + + Expr VisitExpr_(const LetNode* ln) override { + var_map_.insert(std::pair(ln->var, VisitExpr(WrapRec(ln->var, ln->value)))); + return VisitExpr(ln->body); + } +}; + +Expr ToGraphNormalForm(const Expr& e) { + return GNF()(e); +} + +TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ToGraphNormalForm(args[0]); +}); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_to_anf.py b/tests/python/relay/test_to_a_normal_form.py similarity index 92% rename from tests/python/relay/test_to_anf.py rename to tests/python/relay/test_to_a_normal_form.py index e8c7995cfd8e..c15dc8ffc269 100644 --- a/tests/python/relay/test_to_anf.py +++ b/tests/python/relay/test_to_a_normal_form.py @@ -1,7 +1,7 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type +from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type from tvm.relay import op, create_executor from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.prelude import Prelude @@ -21,7 +21,7 @@ def test_explicit_bound(): z = op.add(y, y) f = relay.Function([], op.add(z, z)) assert not "let" in f.astext() # assert the values are implicitly bounded - anf = to_anf(f) + anf = to_a_normal_form(f) assert "let" in anf.astext() # assert the values are explicitly bounded check_eval(f(), 8.0) check_eval(anf(), 8.0) @@ -35,7 +35,7 @@ def test_order(): x = relay.const(1) val = x + y * z check_eval(val, 7.0) - anf = infer_type(to_anf(val)) + anf = infer_type(to_a_normal_form(val)) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -54,7 +54,7 @@ def test_order(): def test_if(): cond = relay.const(True) x = relay.If(cond, relay.const(2), relay.const(3)) - anf = infer_type(to_anf(x)) + anf = infer_type(to_a_normal_form(x)) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -96,7 +96,7 @@ def test_recursion(): mod[f] = value check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) old_f = mod[f] - f = to_anf(f, mod=mod) + f = to_a_normal_form(f, mod=mod) check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) @@ -111,7 +111,7 @@ def test_ref(): body = relay.Let(iv, relay.RefRead(i), body) body = relay.Let(i, relay.RefCreate(relay.const(1)), body) check_eval(body, 3) - check_eval(to_anf(body), 3) + check_eval(to_a_normal_form(body), 3) # this is an example of using the adt value in python side @@ -135,7 +135,7 @@ def test_add(): intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(intrp.evaluate(add(s(z()), s(z())))) == 2 - assert count(intrp.evaluate(to_anf(add(s(z()), s(z())), mod))) == 2 + assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert "let" in mod[add].astext() if __name__ == '__main__': diff --git a/tests/python/relay/test_to_graph_normal_form.py b/tests/python/relay/test_to_graph_normal_form.py new file mode 100644 index 000000000000..ac86799b6b8c --- /dev/null +++ b/tests/python/relay/test_to_graph_normal_form.py @@ -0,0 +1,51 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal +from tvm.relay import op, create_executor +from tvm.relay.backend.interpreter import Value, TupleValue + + +def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): + if mod is None: + mod = relay.Module() + + ctx = tvm.context("llvm", 0) + intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + + result = intrp.evaluate(expr)(*args) + np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) + + +def test_implicit_share(): + x = relay.Var('x') + y = relay.Var('y') + z = relay.Var('z') + body = relay.Let(z, op.add(y, y), op.add(z, z)) + body = relay.Let(y, op.add(x, x), body) + f = relay.Function([], relay.Let(x, relay.const(1), body)) + g = to_graph_normal_form(f) + assert "let" in f.astext() + assert not "let" in g.astext() + check_eval(f, [], 8.0) + check_eval(g, [], 8.0) + + +def test_round_trip(): + x = relay.Var('x') + y = relay.Var('y') + z = relay.Var('z') + body = relay.Let(z, op.add(y, y), op.add(z, z)) + body = relay.Let(y, op.add(x, x), body) + f = relay.Function([], relay.Let(x, relay.const(1), body)) + g = to_graph_normal_form(f) + h = to_a_normal_form(g) + assert "let" in f.astext() + assert not "let" in g.astext() + check_eval(f, [], 8.0) + check_eval(g, [], 8.0) + check_eval(h, [], 8.0) + +if __name__ == '__main__': + test_implicit_share() + test_round_trip()