Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Basic block normal form #6152

Merged
merged 33 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
*/
TVM_DLL bool ConstantCheck(const Expr& e);

/*!
* \brief Check whether an expression is in the basic block normal form.
*
* \param e the expression.
*
* \return whether the expression is in the basic block normal form.
*/
TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e);

/*!
* \brief Check that each Var is only bound once.
*
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
*/
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);

/*!
* \brief Turn an expression to Basic Block Normal Form.
*
* We define a block as a group of expressions implied by the scope structure.
*
* Each graph node can only belong to a single block.
*
* For any value that is being used in multiple blocks, it has to be referred
* by a Var which is defined in a block, whose scope is the least common ancestor
* of blocks this value is used.
*
* \return The pass.
*/
TVM_DLL Pass ToBasicBlockNormalForm();

/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,21 @@ def check_constant(expr):
"""
return _ffi_api.check_constant(expr)

def check_basic_block_normal_form(expr):
"""Check whether an expression is in the basic block form

Parameters
----------
expr : tvm.relay.Expr
The input expression

Returns
-------
result : bool
Whether the expression is in the basic block form.
"""
return _ffi_api.check_basic_block_normal_form(expr)


def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,21 @@ def ToANormalForm():
"""
return _ffi_api.ToANormalForm()

def ToBasicBlockNormalForm():
"""Turn an expression to Basic Block Normal Form.
We define a block as a group of expressions implied by the scope structure.
Each graph node can only belong to a single block.
For any value that is being used in multiple blocks, it has to be referred
by a Var which is defined in a block, whose scope is the least common ancestor
of blocks this value is used.

Returns
-------
ret: tvm.transform.Pass
The registered pass that transforms an expression into Basic Block Normal Form.
"""
return _ffi_api.ToBasicBlockNormalForm()


def ToCPS(expr, mod=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/relay/analysis/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
for (const auto& p : f->params) {
Depend(b, p);
}
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
Expand All @@ -145,6 +148,7 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->var);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());

// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

Expand Down
6 changes: 6 additions & 0 deletions src/relay/transforms/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class LetList {
return ret;
}

/*! \brief get the number of let bindings in the let list.
*
* \return the let list size.
*/
size_t size() const { return lets_.size(); }

/*! \brief generate an LetList and wrap the result automatically.
*
* \param f a function that generate the unwrapped Expr.
Expand Down
36 changes: 36 additions & 0 deletions src/relay/transforms/pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "../analysis/dependency_graph.h"
#include "let_list.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -184,6 +189,37 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
~TreeBranchNode() {}
};

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

/* Invariant: when parent is null level is 0
* Invariant: when parent is not null level is 1 + parent->level
*/
struct ScopeNode {
// the level of the scope
size_t level;
// the parent scope
Scope parent;
// the corresponding let list which holds all let bindings in the scope
std::shared_ptr<LetList> ll = std::make_shared<LetList>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a more expressive variable name, maybe let_list?

explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
ScopeNode() : level(0) {}
};

/*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor.
*
* \param dg the input dependency graph
* \param expr_scope the output node -> scope mapping for all nodes.
* \param lifted_exprs the output set of expressions whose scope is lifted due to dependency
*/
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);

/*! \brief find the least common ancestor of lhs scope and rhs scope.
*/
Scope LCA(Scope lhs, Scope rhs);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
48 changes: 25 additions & 23 deletions src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,6 @@
namespace tvm {
namespace relay {

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;

/* Invariant: when parent is null level is 0
*
* Invariant: when parent is not null level is 1 + parent->level
*/
struct ScopeNode {
size_t level;
Scope parent;
std::shared_ptr<LetList> ll = std::make_shared<LetList>();
explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
ScopeNode() : level(0) {}
};

Scope ChildScope(const Scope& s) { return std::make_shared<ScopeNode>(s); }

Scope LCA(Scope lhs, Scope rhs) {
while (lhs != rhs) {
if (lhs->level > rhs->level) {
Expand All @@ -67,10 +50,16 @@ Scope LCA(Scope lhs, Scope rhs) {
return lhs;
}

std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGraph& dg) {
std::unordered_map<DependencyGraph::Node*, Scope> expr_scope;
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
NodeScopeMap expr_scope;
ExprSet lifted_exprs;
std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
for (auto expr_node : dg.expr_node) {
node_to_expr[expr_node.second] = expr_node.first;
}
bool global_scope_used = false;
Scope global_scope = std::make_shared<ScopeNode>();

for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it;
auto iit = n->parents.head;
Expand All @@ -81,15 +70,28 @@ std::unordered_map<DependencyGraph::Node*, Scope> CalcScope(const DependencyGrap
global_scope_used = true;
} else {
s = expr_scope.at(iit->value);
const auto original_s = s;
iit = iit->next;
for (; iit != nullptr; iit = iit->next) {
s = LCA(s, expr_scope.at(iit->value));
}
if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
// filter out exprs whose scope do not matter
Expr expr = node_to_expr[n];
if (!expr.as<OpNode>()) {
lifted_exprs.insert(expr);
}
}
}
if (n->new_scope) {
auto child_scope = std::make_shared<ScopeNode>(s);
expr_scope.insert({n, child_scope});
} else {
expr_scope.insert({n, s});
}
expr_scope.insert({n, n->new_scope ? ChildScope(s) : s});
}
CHECK(global_scope_used);
return expr_scope;
return std::make_pair(expr_scope, lifted_exprs);
}

/* Special care is needed to handle local recursion.
Expand Down Expand Up @@ -269,8 +271,8 @@ Expr ToANormalFormAux(const Expr& e) {
* Every scope additionally contain a LetList which collect all value of that scope.
* We do an additional pass to fill all the LetList and we are done.
*/
std::unordered_map<DependencyGraph::Node*, Scope> node_scope = CalcScope(dg);
return Fill::ToANormalForm(e, dg, &node_scope);
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToANormalForm(e, dg, &scopes.first);
}

IRModule ToANormalForm(const IRModule& m) {
Expand Down
Loading