Skip to content

Commit

Permalink
save (#3033)
Browse files Browse the repository at this point in the history
save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
  • Loading branch information
MarisaKirisame authored and jroesch committed Jun 15, 2019
1 parent 50dd03c commit df88c41
Show file tree
Hide file tree
Showing 7 changed files with 624 additions and 171 deletions.
11 changes: 7 additions & 4 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e);
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down Expand Up @@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);

/*!
* \brief Bind the free variables to a Relay expression.
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param inline_once whether or not to inline binding used one.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down
77 changes: 42 additions & 35 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand Down Expand Up @@ -175,7 +175,7 @@ def free_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -213,7 +213,7 @@ def all_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
Expand All @@ -304,48 +306,50 @@ def canonicalize_ops(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)


def dead_code_elimination(expr):
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr)
return _ir_pass.dead_code_elimination(expr, inline_once)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
Expand All @@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
Expand All @@ -378,12 +382,12 @@ def structural_hash(value):
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
Expand Down Expand Up @@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
Expand All @@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
Expand Down Expand Up @@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
Expand All @@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.
fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)

def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Expand All @@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)

def unmatched_cases(match, mod=None):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});

Let LetNode::make(Var var, Expr value, Expr body) {
Expand Down Expand Up @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});

} // namespace relay
Expand Down
28 changes: 18 additions & 10 deletions src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}

Expand Down Expand Up @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;

bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}

Expr VisitExpr_(const VarNode* op) final {
Expand All @@ -144,19 +152,19 @@ class CalcDep : private ExprVisitor {
};
};

Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}

TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);

namespace transform {

Pass DeadCodeElimination() {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
Expand Down
Loading

0 comments on commit df88c41

Please sign in to comment.