forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fcda217
commit 4a14e2a
Showing
6 changed files
with
169 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* | ||
* \file to_gnf.cc | ||
* | ||
* \brief Turn A normal form into graph normal form. | ||
*/ | ||
#include <tvm/relay/pass.h> | ||
#include <tvm/relay/expr_functor.h> | ||
#include "let_list.h" | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
class UseVarVisitor : public ExprVisitor { | ||
public: | ||
explicit UseVarVisitor(const Var& v) : v(v) { } | ||
|
||
static bool UseVar(const Var& v, const Expr& e) { | ||
UseVarVisitor uv(v); | ||
uv(e); | ||
return uv.use_var; | ||
} | ||
|
||
private: | ||
bool use_var = false; | ||
Var v; | ||
|
||
void VisitExpr_(const VarNode* vn) override { | ||
use_var = use_var || (v == GetRef<Var>(vn)); | ||
} | ||
}; | ||
|
||
class GNF : public ExprMutator { | ||
private: | ||
std::unordered_map<Var, Expr, NodeHash, NodeEqual> var_map_; | ||
Expr VisitExpr_(const VarNode* vn) override { | ||
Var v = GetRef<Var>(vn); | ||
return var_map_.count(v) == 0 ? v : var_map_.at(v); | ||
} | ||
|
||
static bool UseVar(const Var& v, const Expr& e) { | ||
return UseVarVisitor::UseVar(v, e); | ||
} | ||
|
||
static Expr WrapRec(const Var& var, const Expr& val) { | ||
return UseVar(var, val) ? LetNode::make(var, val, var) : val; | ||
} | ||
|
||
Expr VisitExpr_(const LetNode* ln) override { | ||
var_map_.insert(std::pair<Var, Expr>(ln->var, VisitExpr(WrapRec(ln->var, ln->value)))); | ||
return VisitExpr(ln->body); | ||
} | ||
}; | ||
|
||
Expr ToGraphNormalForm(const Expr& e) { | ||
return GNF()(e); | ||
} | ||
|
||
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") | ||
.set_body([](TVMArgs args, TVMRetValue* ret) { | ||
*ret = ToGraphNormalForm(args[0]); | ||
}); | ||
|
||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy as np | ||
import tvm | ||
from tvm import relay | ||
from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal | ||
from tvm.relay import op, create_executor | ||
from tvm.relay.backend.interpreter import Value, TupleValue | ||
|
||
|
||
def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): | ||
if mod is None: | ||
mod = relay.Module() | ||
|
||
ctx = tvm.context("llvm", 0) | ||
intrp = create_executor(mod=mod, ctx=ctx, target="llvm") | ||
|
||
result = intrp.evaluate(expr)(*args) | ||
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) | ||
|
||
|
||
def test_implicit_share(): | ||
x = relay.Var('x') | ||
y = relay.Var('y') | ||
z = relay.Var('z') | ||
body = relay.Let(z, op.add(y, y), op.add(z, z)) | ||
body = relay.Let(y, op.add(x, x), body) | ||
f = relay.Function([], relay.Let(x, relay.const(1), body)) | ||
g = to_graph_normal_form(f) | ||
assert "let" in f.astext() | ||
assert not "let" in g.astext() | ||
check_eval(f, [], 8.0) | ||
check_eval(g, [], 8.0) | ||
|
||
|
||
def test_round_trip(): | ||
x = relay.Var('x') | ||
y = relay.Var('y') | ||
z = relay.Var('z') | ||
body = relay.Let(z, op.add(y, y), op.add(z, z)) | ||
body = relay.Let(y, op.add(x, x), body) | ||
f = relay.Function([], relay.Let(x, relay.const(1), body)) | ||
g = to_graph_normal_form(f) | ||
h = to_a_normal_form(g) | ||
assert "let" in f.astext() | ||
assert not "let" in g.astext() | ||
check_eval(f, [], 8.0) | ||
check_eval(g, [], 8.0) | ||
check_eval(h, [], 8.0) | ||
|
||
if __name__ == '__main__': | ||
test_implicit_share() | ||
test_round_trip() |