fix conv2d and conv2d alter op layout for x86
icemelon committed Jan 17, 2020
1 parent 69b171c commit 6b4cb34
Showing 18 changed files with 290 additions and 141 deletions.
3 changes: 2 additions & 1 deletion include/tvm/relay/op_attr_types.h
@@ -140,7 +140,8 @@ using FTVMStrategy = GenericFunc;
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;
const Array<Tensor>& tinfos,
const Type& out_type)>;

/*!
* \brief Convert the layout of operators or replace the
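For reference, a minimal sketch (not part of this commit) of a Python hook matching the widened FTVMAlterOpLayout signature. The parameter names mirror the conv2d registration changed later in this diff; the body is illustrative only.

    def alter_conv2d_layout(attrs, inputs, tinfos, out_type):
        # attrs:    attributes of the call being rewritten (e.g. data_layout)
        # inputs:   the relay.Expr arguments of the call
        # tinfos:   placeholder tensors describing the input shapes and dtypes
        # out_type: the inferred output type, newly passed through by this change
        # Returning None keeps the original operator untouched.
        return None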
3 changes: 3 additions & 0 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -65,6 +65,9 @@ def expr2graph(expr, target_ops, node_dict, node_list):
% op_name)
topi_funcs += OP2COMPUTE[op_name]
env.reset(topi_funcs)
# TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
# that # autotvm tasks == # ops. But this won't be true after having relay op
# strategy. We need to find a solution to fix this.
with env:
_expr2graph_impl(expr, target_ops, node_dict, node_list)
task_pos = 0
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/__init__.py
@@ -30,5 +30,5 @@
FallbackContext, clear_fallback_cache, ApplyGraphBest

from .topi_integration import register_topi_compute, register_topi_schedule, \
TaskExtractEnv, register_topi_compute2, register_topi_schedule2
TaskExtractEnv, register_topi_compute2, register_topi_schedule2, get_workload
from .relay_integration import extract_from_program, extract_from_multiple_program
11 changes: 5 additions & 6 deletions python/tvm/autotvm/task/dispatcher.py
@@ -483,13 +483,14 @@ def _query_inside(self, target, workload):
cfg : ConfigSpace
The specific configuration.
"""
print('=' * 80)
print('query graph dispatcher: %s, %s' % (target, workload))
if self._counter < len(self._records):
cfg = self._records[self._counter][0].config
wkl = self._records[self._counter][0].task.workload
if workload is not None:
assert wkl == workload
self._counter += 1
print(self._counter, cfg)
self.update(target, workload, cfg)
self.update(target, wkl, cfg)
cfg.workload = wkl
return cfg
key = (str(target), workload)
if key not in self._global_cfg_dict:
@@ -504,7 +505,5 @@ def _query_inside(self, target, workload):
return cfg

def update(self, target, workload, cfg):
print('-' * 80)
print('update %s %s -> %s' % (target, workload, cfg))
key = (str(target), workload)
self._global_cfg_dict[key] = cfg
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/relay_integration.py
@@ -171,7 +171,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,

# create tasks for target
tasks = []
for task_name, args, _ in env.get_tasks():
for task_name, args in env.get_tasks():
try:
key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
tsk = create(task_name, args,
41 changes: 21 additions & 20 deletions python/tvm/autotvm/task/topi_integration.py
@@ -287,7 +287,7 @@ def reset(self, wanted_topi_funcs):
if wanted_topi_funcs is not None:
self.wanted_topi_funcs = wanted_topi_funcs

def add_task(self, task_name, args, cond=None):
def add_task(self, task_name, args):
"""Add AutoTVM task
Parameters
@@ -301,9 +301,7 @@ def add_task(self, task_name, args):
cond: SpecializedCondition
Specialized condition to enable the TOPI template.
"""
assert cond is None, \
"AutoTVM currently doesn't support tuning under specialized condition"
key = (task_name, serialize_args(args), None)
key = (task_name, serialize_args(args))
if self.allow_duplicate or key not in self.task_collection:
self.task_collection.append(key)

@@ -515,7 +513,7 @@ def wrapper(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
task_env = TaskExtractEnv.current
if task_env is not None and task_env.tracing:
task_env.add_task(task_name, args, current_specialization())
task_env.add_task(task_name, args)
workload = args_to_workload2(args, task_name)
tgt = _target.current_target()
cfg = DispatchContext.current.query(tgt, workload)
@@ -548,31 +546,34 @@ def wrapper(*args, **kwargs):
return _decorate(func)
return _decorate


def register_topi_schedule2(task_name, func=None):
def _decorate(topi_schedule):
@register_task_schedule(task_name)
def wrapper(outs, *args, **kwargs):
def traverse(tensors):
"""traverse all ops to find attached workload"""
for t in tensors:
op = t.op
if 'workload' in op.attrs:
return op.attrs['workload']
wkl = traverse(op.input_tensors)
if wkl:
return wkl
return None

outs = [outs] if isinstance(outs, tensor.Tensor) else outs
workload = traverse(outs)
workload = get_workload(outs)
if workload is None:
raise RuntimeError("Cannot find workload in attribute of this schedule")
workload = args_to_workload2(workload)
tgt = _target.current_target()
cfg = DispatchContext.current.query(tgt, workload)
return topi_schedule(cfg, outs, *args, **kwargs)

return wrapper
if func:
return _decorate(func)
return _decorate
return _decorate


def get_workload(outs):
def traverse(tensors):
"""traverse all ops to find attached workload"""
for t in tensors:
op = t.op
if 'workload' in op.attrs:
return op.attrs['workload']
wkl = traverse(op.input_tensors)
if wkl:
return wkl
return None
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return traverse(outs)
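The get_workload helper factored out here walks the producer ops of a schedule's output tensors until it finds the 'workload' attribute attached by a register_topi_compute2 template; it is reused by register_topi_schedule2 above and by the compile engine changes further down in this commit. A self-contained, pure-Python sketch of that traversal, using mock stand-ins for TVM tensors and ops (all names illustrative, no TVM required):

    class MockOp:
        def __init__(self, attrs, input_tensors=()):
            self.attrs = attrs
            self.input_tensors = list(input_tensors)

    class MockTensor:
        def __init__(self, op):
            self.op = op

    def get_workload(outs):
        """Depth-first search over producer ops for the first 'workload' attribute."""
        def traverse(tensors):
            for t in tensors:
                op = t.op
                if 'workload' in op.attrs:
                    return op.attrs['workload']
                wkl = traverse(op.input_tensors)
                if wkl:
                    return wkl
            return None
        outs = [outs] if isinstance(outs, MockTensor) else outs
        return traverse(outs)

    data = MockTensor(MockOp({}))                                        # placeholder-like input
    conv = MockTensor(MockOp({'workload': ('conv2d', 'args')}, [data]))  # compute op tagged by its template
    relu = MockTensor(MockOp({}, [conv]))                                # elementwise op fused on top
    assert get_workload(relu) == ('conv2d', 'args')                      # found through the fused output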
111 changes: 62 additions & 49 deletions python/tvm/relay/backend/compile_engine.py
@@ -25,7 +25,7 @@
from ... import _api_internal
from ... import target as _target
from ..._ffi.function import register_func
from ...autotvm import task as _task
from ... import autotvm
from .. import expr as _expr
from .. import op as _op
from .. import ty as _ty
@@ -97,6 +97,60 @@ def get_shape(shape):
return ret


def get_valid_implements(op, attrs, inputs, out_type, target):
"""only use this function with concrete shapes"""
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
with target:
strategy = fstrategy(attrs, inputs, out_type, target)
ret = []
for spec in strategy.specializations:
if spec.condition:
for clause in spec.condition.clauses:
clause = tvm.ir_pass.Simplify(clause)
if isinstance(clause, tvm.expr.IntImm) and int(clause):
ret.append(impl)
else:
for impl in spec.implements:
ret.append(impl)
return ret


def select_implement(op, attrs, inputs, out_type, target, use_autotvm=True):
"""only use this function with concrete shapes"""
all_impls = get_valid_implements(op, attrs, inputs, out_type, target)

best_plevel_impl = None
for impl in all_impls:
if best_plevel_impl is None or int(impl.plevel) > int(best_plevel_impl.plevel):
best_plevel_impl = impl
if not use_autotvm:
outs = best_plevel_impl.compute(attrs, inputs, out_type)
return best_plevel_impl, outs

outputs = {}
best_autotvm_impl = None
best_cfg = None
dispatch_ctx = autotvm.task.DispatchContext.current
for impl in all_impls:
outs = impl.compute(attrs, inputs, out_type)
outputs[impl] = outs
workload = autotvm.task.get_workload(outs)
if workload is None:
continue
workload = autotvm.task.args_to_workload2(workload)
cfg = dispatch_ctx.query(target, workload)
if cfg.cost is None:
# It's a fallback config
continue
if best_cfg is None or best_cfg.cost > cfg.cost:
best_autotvm_impl = impl
best_cfg = cfg
if best_autotvm_impl:
return best_autotvm_impl, outputs[best_autotvm_impl]
return best_plevel_impl, outputs[best_plevel_impl]


class ScheduleGetter(ExprVisitor):
"""Get the schedule given a fused Relay function"""

@@ -199,35 +253,18 @@ def visit_call(self, call):
outputs = [_api_internal._Tensor(copy_input.shape, copy_input.dtype,
None, 0)]
else:
fstrategy = op.get_attr("FTVMStrategy")
assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
strategy = fstrategy(call.attrs, inputs, ret_type, self.target)
op_spec = None
# TODO(@icemelon9): current only use the default specialization (with no condition)
for spec in strategy.specializations:
if spec.condition is None:
op_spec = spec
break
assert op_spec is not None, \
"Cannot find default specialization for op %s" % op.name
assert len(op_spec.implements) > 0

is_dyn = call.checked_type.is_dynamic()
for arg in call.args:
is_dyn = is_dyn or arg.checked_type.is_dynamic()

if not is_dyn:
best_imp = self.get_best_implement_by_autotvm(
op_spec, call.attrs, inputs, ret_type)
if best_imp is None:
best_imp = self.get_best_implement_by_plevel(
op_spec, call.attrs, inputs, ret_type)
best_impl, outputs = select_implement(
op, call.attrs, inputs, ret_type, self.target)
else:
# for dynamic case, we just use the implementation with highest score
best_imp = self.get_best_implement_by_plevel(
op_spec, call.attrs, inputs, ret_type)
assert best_imp is not None
outputs = best_imp.compute(call.attrs, inputs, ret_type)
# TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes
# for dynamic case, we currently use the implementation with highest plevel
best_impl, outputs = select_implement(
op, call.attrs, inputs, ret_type, self.target, use_autotvm=False)
op_pattern = op.get_attr("TOpPattern")
if op_pattern >= _op.OpPattern.COMM_REDUCE:
assert self.master_op is None or self.master_op_pattern < _op.OpPattern.COMM_REDUCE, \
@@ -237,7 +274,7 @@ def visit_call(self, call):
self.master_op = op
self.master_attrs = call.attrs
self.master_op_pattern = op_pattern
self.master_implement = best_imp
self.master_implement = best_impl
if len(outputs) > 1:
assert isinstance(call.checked_type, _ty.TupleType)
assert len(call.checked_type.fields) == len(outputs)
@@ -269,30 +306,6 @@ def visit_tuple_getitem(self, op):
assert op.index < tup.size()
return [tup[op.index]]

def get_best_implement_by_autotvm(self, op_spec, attrs, inputs, ret_type):
min_cost = None
best_imp = None
for imp in op_spec.implements:
outs = imp.compute(attrs, inputs, ret_type)
if 'workload' not in outs[0].op.attrs:
continue
workload = _task.args_to_workload2(outs[0].op.attrs['workload'])
cfg = _task.DispatchContext.current.query(self.target, workload)
if cfg.cost is None:
# This is fallback config
continue
if min_cost is None or min_cost > cfg.cost:
min_cost = cfg.cost
best_imp = imp
return best_imp

def get_best_implement_by_plevel(self, op_spec, attrs, inputs, ret_type):
best_imp = None
for imp in op_spec.implements:
if best_imp is None or int(imp.plevel) > int(best_imp.plevel):
best_imp = imp
return best_imp


@register_func("relay.backend.create_schedule")
def create_schedule(src_func, target):
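The policy applied by select_implement (added in the compile_engine.py changes above) can be stated without any TVM machinery: among the valid implementations, prefer the one whose AutoTVM record has the lowest cost, and only when no implementation has a tuned (non-fallback) config fall back to the highest plevel. A pure-Python sketch of that policy with mocked implementations and a mocked cost lookup (all names illustrative; the real function also returns the computed output tensors):

    from collections import namedtuple

    MockImpl = namedtuple('MockImpl', ['name', 'plevel'])

    def select_implement(impls, tuned_cost):
        """tuned_cost maps an implementation name to its best AutoTVM cost;
        a missing name means only a fallback config exists for it."""
        best_plevel_impl = max(impls, key=lambda impl: impl.plevel)
        best_autotvm_impl, best_cost = None, None
        for impl in impls:
            cost = tuned_cost.get(impl.name)
            if cost is None:
                continue                              # fallback config: skip in this pass
            if best_cost is None or cost < best_cost:
                best_autotvm_impl, best_cost = impl, cost
        return best_autotvm_impl or best_plevel_impl

    impls = [MockImpl('conv2d_nchw', 10), MockImpl('conv2d_NCHWc', 15)]
    assert select_implement(impls, {}).name == 'conv2d_NCHWc'                    # untuned: highest plevel wins
    assert select_implement(impls, {'conv2d_nchw': 0.8}).name == 'conv2d_nchw'   # tuned record wins over plevel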
5 changes: 2 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
@@ -183,10 +183,9 @@ def _find_conv2d_op(op):
reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)

@reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos):
def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
"""Alternate the layout of conv2d"""
from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)

@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, types):
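A hedged sketch (not the x86 implementation touched by this commit) of what a TOPI conv2d_alter_layout override looks like under the new calling convention: the hook now receives out_type in place of the relay op module, giving it direct access to the inferred output type. The registration key and every name below are illustrative assumptions:

    import topi

    @topi.nn.conv2d_alter_layout.register(["cpu"], override=True)
    def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
        data, kernel = tinfos              # placeholder tensors for the data and weight inputs
        out_dtype = out_type.dtype         # inferred output dtype, taken from the new argument
        new_attrs = {k: attrs[k] for k in attrs.keys()}   # usual starting point for rewritten attrs
        # ... choose new data/kernel layouts (e.g. NCHWc from a tuned config) and build the new call ...
        # Returning None keeps the original nn.conv2d call unchanged.
        return None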
10 changes: 0 additions & 10 deletions python/tvm/relay/op/op.py
@@ -347,16 +347,6 @@ def _build(lowered_funcs):
_schedule_injective = None
_schedule_reduce = None

# def register_injective_schedule(target, schedule):
# def wrap_schedule(_, outs):
# return schedule(outs)
# _injective_schedule_map.append([target, wrap_schedule])
#
# def register_reduce_schedule(target, schedule):
# def wrap_schedule(_, outs):
# return schedule(outs)
# _reduce_schedule_map.append([target, wrap_schedule])

__DEBUG_COUNTER__ = 0

def debug(expr, debug_func=None):