Skip to content

Commit

Permalink
[BUILD] Allow inject custom pass via phase (#408)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Sep 1, 2017
1 parent f73c461 commit 0138997
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __exit__(self, ptype, value, trace):
assert self._old_scope
BuildConfig.current = self._old_scope


BuildConfig.current = BuildConfig()

def build_config(**kwargs):
Expand Down Expand Up @@ -102,7 +103,8 @@ def build_config(**kwargs):
Whether split the loop containing double buffer so
that the buffer fetching won't contain condition.
add_lower_pass: list of function(Stmt->Stmt), default=None
add_lower_pass: list of tuiple (phase, function(Stmt->Stmt)), default=None
phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api.
Returns
Expand Down Expand Up @@ -193,11 +195,19 @@ def lower(sch,
"""
binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] > 1]
# normalize schedule first
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
stmt = ir_pass.StorageFlatten(stmt, binds, 64)
stmt = ir_pass.CanonicalSimplify(stmt)
if not simple_mode:
Expand All @@ -211,13 +221,15 @@ def lower(sch,
cfg.auto_unroll_max_step,
cfg.auto_unroll_min_depth,
cfg.unroll_explicit)
if cfg.add_lower_pass:
for f in cfg.add_lower_pass:
stmt = f(stmt)
for f in lower_phase1:
stmt = f(stmt)
# Phase 2
stmt = ir_pass.Simplify(stmt)
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt)
for f in lower_phase2:
stmt = f(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
Expand Down

0 comments on commit 0138997

Please sign in to comment.