Skip to content

Commit

Permalink
[PASS] Canonical form simplify (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Feb 7, 2017
1 parent 2bcf3f2 commit 8837798
Show file tree
Hide file tree
Showing 11 changed files with 614 additions and 11 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ bool HasSideEffect(const Expr& e);
*/
Stmt ConvertSSA(Stmt stmt);

/*!
* \brief Simplify by applying canonical form.
* \param stmt The statement to be canonically simplifed.
* \return Canonicalized statement.
*/
Stmt CanonicalSimplify(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def build(sch,
target,
name="default_function",
binds=None,
record_codes=None):
record_codes=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
Parameters
Expand All @@ -38,6 +39,9 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int
Maximum step to perform automatic unrolling
Returns
-------
f : Function, or pair of functions
Expand All @@ -64,6 +68,8 @@ def build(sch,
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi)
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ TVM_REGISTER_API(_pass_PostOrderVisit)

REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
Expand Down
Loading

0 comments on commit 8837798

Please sign in to comment.