Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god

Update src/relay/pass/partial_eval.cc

Co-Authored-By: Wei Chen <ipondering.weic@gmail.com>

save

fix test

ci is ANGRY

fix rebase problem

fix rebase

add test

save

save

comment
  • Loading branch information
MarisaKirisame committed Jun 15, 2019
1 parent b8fa8f6 commit 0c0d186
Show file tree
Hide file tree
Showing 7 changed files with 624 additions and 171 deletions.
11 changes: 7 additions & 4 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,15 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e);
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down Expand Up @@ -435,11 +437,12 @@ TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e);
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);

/*!
* \brief Bind the free variables to a Relay expression.
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,11 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \param inline_once whether or not to inline binding used one.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination();
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);

/*!
* \brief Fold constant expressions.
Expand Down
77 changes: 42 additions & 35 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand Down Expand Up @@ -175,7 +175,7 @@ def free_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -197,7 +197,7 @@ def bound_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -213,7 +213,7 @@ def all_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -229,9 +229,10 @@ def free_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -248,9 +249,10 @@ def bound_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -267,9 +269,9 @@ def all_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -286,12 +288,12 @@ def simplify_inference(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
Expand All @@ -304,48 +306,50 @@ def canonicalize_ops(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)


def dead_code_elimination(expr):
def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr)
return _ir_pass.dead_code_elimination(expr, inline_once)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
Expand All @@ -359,15 +363,15 @@ def graph_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
Expand All @@ -378,12 +382,12 @@ def structural_hash(value):
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
Expand Down Expand Up @@ -544,12 +548,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
Expand All @@ -563,7 +567,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
Expand Down Expand Up @@ -612,7 +616,7 @@ def get_total_mac_number(expr):
Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
Expand All @@ -627,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.
fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)

def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Expand All @@ -646,12 +650,15 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)

def unmatched_cases(match, mod=None):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});

Let LetNode::make(Var var, Expr value, Expr body) {
Expand Down Expand Up @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});

} // namespace relay
Expand Down
28 changes: 18 additions & 10 deletions src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ namespace relay {
// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e) {
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
return el(e);
}

Expand Down Expand Up @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor {
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { }
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;

bool HasLet(const Var& v) {
// TODO(@jroesch): MK fix me
return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0));
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}

Expr VisitExpr_(const VarNode* op) final {
Expand All @@ -144,19 +152,19 @@ class CalcDep : private ExprVisitor {
};
};

Expr DeadCodeElimination(const Expr& e) {
return CalcDep::Eliminate(e);
Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}

TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);

namespace transform {

Pass DeadCodeElimination() {
Pass DeadCodeElimination(bool inline_once) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(DeadCodeElimination(f));
return Downcast<Function>(DeadCodeElimination(f, inline_once));
};
return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
}
Expand Down
Loading

0 comments on commit 0c0d186

Please sign in to comment.