From 6b4cb3436fedf0a256fdf71c829fee23c7c21b9b Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Fri, 17 Jan 2020 11:02:52 -0800 Subject: [PATCH] fix conv2d and conv2d alter op layout for x86 --- include/tvm/relay/op_attr_types.h | 3 +- .../graph_tuner/utils/traverse_graph.py | 3 + python/tvm/autotvm/task/__init__.py | 2 +- python/tvm/autotvm/task/dispatcher.py | 11 +- python/tvm/autotvm/task/relay_integration.py | 2 +- python/tvm/autotvm/task/topi_integration.py | 41 ++--- python/tvm/relay/backend/compile_engine.py | 111 ++++++++------ python/tvm/relay/op/nn/_nn.py | 5 +- python/tvm/relay/op/op.py | 10 -- python/tvm/relay/op/strategy/generic.py | 36 +++-- python/tvm/relay/op/strategy/x86.py | 28 +++- src/relay/pass/alter_op_layout.cc | 5 +- topi/python/topi/generic/injective.py | 2 +- topi/python/topi/nn/conv2d.py | 2 +- topi/python/topi/x86/conv2d.py | 2 +- topi/python/topi/x86/conv2d_alter_op.py | 140 ++++++++++++++++-- topi/python/topi/x86/depthwise_conv2d.py | 8 +- tutorials/autotvm/tune_relay_x86.py | 20 +-- 18 files changed, 290 insertions(+), 141 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 141dd03a9bef8..2391367269e9c 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -140,7 +140,8 @@ using FTVMStrategy = GenericFunc; using FTVMAlterOpLayout = runtime::TypedPackedFunc< Expr(const Attrs& attrs, const Array& args, - const Array& tinfos)>; + const Array& tinfos, + const Type& out_type)>; /*! * \brief Convert the layout of operators or replace the diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 19c319361ce5d..324472bbd2b2e 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -65,6 +65,9 @@ def expr2graph(expr, target_ops, node_dict, node_list): % op_name) topi_funcs += OP2COMPUTE[op_name] env.reset(topi_funcs) + # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact + # that # autotvm tasks == # ops. But this won't be true after having relay op + # strategy. We need to find a solution to fix this. with env: _expr2graph_impl(expr, target_ops, node_dict, node_list) task_pos = 0 diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index 1339ec4397c35..d3001059ba40b 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -30,5 +30,5 @@ FallbackContext, clear_fallback_cache, ApplyGraphBest from .topi_integration import register_topi_compute, register_topi_schedule, \ - TaskExtractEnv, register_topi_compute2, register_topi_schedule2 + TaskExtractEnv, register_topi_compute2, register_topi_schedule2, get_workload from .relay_integration import extract_from_program, extract_from_multiple_program diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 1bad94a23aea9..98e32874baee8 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -483,13 +483,14 @@ def _query_inside(self, target, workload): cfg : ConfigSpace The specific configuration. 
""" - print('=' * 80) - print('query graph dispatcher: %s, %s' % (target, workload)) if self._counter < len(self._records): cfg = self._records[self._counter][0].config + wkl = self._records[self._counter][0].task.workload + if workload is not None: + assert wkl == workload self._counter += 1 - print(self._counter, cfg) - self.update(target, workload, cfg) + self.update(target, wkl, cfg) + cfg.workload = wkl return cfg key = (str(target), workload) if key not in self._global_cfg_dict: @@ -504,7 +505,5 @@ def _query_inside(self, target, workload): return cfg def update(self, target, workload, cfg): - print('-' * 80) - print('update %s %s -> %s' % (target, workload, cfg)) key = (str(target), workload) self._global_cfg_dict[key] = cfg diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 133154da40377..55763afcf7bcc 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -171,7 +171,7 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, # create tasks for target tasks = [] - for task_name, args, _ in env.get_tasks(): + for task_name, args in env.get_tasks(): try: key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct' tsk = create(task_name, args, diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 5fe19a7838f28..816c65091051c 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -287,7 +287,7 @@ def reset(self, wanted_topi_funcs): if wanted_topi_funcs is not None: self.wanted_topi_funcs = wanted_topi_funcs - def add_task(self, task_name, args, cond=None): + def add_task(self, task_name, args): """Add AutoTVM task Parameters @@ -301,9 +301,7 @@ def add_task(self, task_name, args, cond=None): cond: SpecializedCondition Specialized condition to enable the TOPI template. 
""" - assert cond is None, \ - "AutoTVM currently doesn't support tuning under specialized condition" - key = (task_name, serialize_args(args), None) + key = (task_name, serialize_args(args)) if self.allow_duplicate or key not in self.task_collection: self.task_collection.append(key) @@ -515,7 +513,7 @@ def wrapper(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" task_env = TaskExtractEnv.current if task_env is not None and task_env.tracing: - task_env.add_task(task_name, args, current_specialization()) + task_env.add_task(task_name, args) workload = args_to_workload2(args, task_name) tgt = _target.current_target() cfg = DispatchContext.current.query(tgt, workload) @@ -548,31 +546,34 @@ def wrapper(*args, **kwargs): return _decorate(func) return _decorate + def register_topi_schedule2(task_name, func=None): def _decorate(topi_schedule): @register_task_schedule(task_name) def wrapper(outs, *args, **kwargs): - def traverse(tensors): - """traverse all ops to find attached workload""" - for t in tensors: - op = t.op - if 'workload' in op.attrs: - return op.attrs['workload'] - wkl = traverse(op.input_tensors) - if wkl: - return wkl - return None - - outs = [outs] if isinstance(outs, tensor.Tensor) else outs - workload = traverse(outs) + workload = get_workload(outs) if workload is None: raise RuntimeError("Cannot find workload in attribute of this schedule") workload = args_to_workload2(workload) tgt = _target.current_target() cfg = DispatchContext.current.query(tgt, workload) return topi_schedule(cfg, outs, *args, **kwargs) - return wrapper if func: return _decorate(func) - return _decorate \ No newline at end of file + return _decorate + + +def get_workload(outs): + def traverse(tensors): + """traverse all ops to find attached workload""" + for t in tensors: + op = t.op + if 'workload' in op.attrs: + return op.attrs['workload'] + wkl = traverse(op.input_tensors) + if wkl: + return wkl + return None + outs = [outs] if isinstance(outs, tensor.Tensor) else outs + return traverse(outs) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 19b2c0eaebbb1..9b18d09cec478 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -25,7 +25,7 @@ from ... import _api_internal from ... import target as _target from ..._ffi.function import register_func -from ...autotvm import task as _task +from ... import autotvm from .. import expr as _expr from .. import op as _op from .. 
import ty as _ty
@@ -97,6 +97,60 @@ def get_shape(shape):
     return ret
 
 
+def get_valid_implements(op, attrs, inputs, out_type, target):
+    """only use this function with concrete shapes"""
+    fstrategy = op.get_attr("FTVMStrategy")
+    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
+    with target:
+        strategy = fstrategy(attrs, inputs, out_type, target)
+    ret = []
+    for spec in strategy.specializations:
+        if spec.condition:
+            # keep these impls only if every clause simplifies to a true constant
+            clauses = [tvm.ir_pass.Simplify(c) for c in spec.condition.clauses]
+            if all(isinstance(c, tvm.expr.IntImm) and int(c) for c in clauses):
+                ret.extend(spec.implements)
+        else:
+            for impl in spec.implements:
+                ret.append(impl)
+    return ret
+
+
+def select_implement(op, attrs, inputs, out_type, target, use_autotvm=True):
+    """only use this function with concrete shapes"""
+    all_impls = get_valid_implements(op, attrs, inputs, out_type, target)
+
+    best_plevel_impl = None
+    for impl in all_impls:
+        if best_plevel_impl is None or int(impl.plevel) > int(best_plevel_impl.plevel):
+            best_plevel_impl = impl
+    if not use_autotvm:
+        outs = best_plevel_impl.compute(attrs, inputs, out_type)
+        return best_plevel_impl, outs
+
+    outputs = {}
+    best_autotvm_impl = None
+    best_cfg = None
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    for impl in all_impls:
+        outs = impl.compute(attrs, inputs, out_type)
+        outputs[impl] = outs
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            continue
+        workload = autotvm.task.args_to_workload2(workload)
+        cfg = dispatch_ctx.query(target, workload)
+        if cfg.cost is None:
+            # It's a fallback config
+            continue
+        if best_cfg is None or best_cfg.cost > cfg.cost:
+            best_autotvm_impl = impl
+            best_cfg = cfg
+    if best_autotvm_impl:
+        return best_autotvm_impl, outputs[best_autotvm_impl]
+    return best_plevel_impl, outputs[best_plevel_impl]
+
+
 class ScheduleGetter(ExprVisitor):
     """Get the schedule given a fused Relay function"""
 
@@ -199,35 +253,18 @@ def visit_call(self, call):
                 outputs = [_api_internal._Tensor(copy_input.shape, copy_input.dtype,
                                                  None, 0)]
         else:
-            fstrategy = op.get_attr("FTVMStrategy")
-            assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
-            strategy = fstrategy(call.attrs, inputs, ret_type, self.target)
-            op_spec = None
-            # TODO(@icemelon9): current only use the default specialization (with no condition)
-            for spec in strategy.specializations:
-                if spec.condition is None:
-                    op_spec = spec
-                    break
-            assert op_spec is not None, \
-                "Cannot find default specialization for op %s" % op.name
-            assert len(op_spec.implements) > 0
-
             is_dyn = call.checked_type.is_dynamic()
             for arg in call.args:
                 is_dyn = is_dyn or arg.checked_type.is_dynamic()
 
             if not is_dyn:
-                best_imp = self.get_best_implement_by_autotvm(
-                    op_spec, call.attrs, inputs, ret_type)
-                if best_imp is None:
-                    best_imp = self.get_best_implement_by_plevel(
-                        op_spec, call.attrs, inputs, ret_type)
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target)
             else:
-                # for dynamic case, we just use the implementation with highest score
-                best_imp = self.get_best_implement_by_plevel(
-                    op_spec, call.attrs, inputs, ret_type)
-            assert best_imp is not None
-            outputs = best_imp.compute(call.attrs, inputs, ret_type)
+                # TODO(@icemelon9): Allow tvm to generate multiple kernels for dynamic shapes
+                # for dynamic case, we currently use the implementation with highest plevel
+                best_impl, outputs = select_implement(
+                    op, call.attrs, inputs, ret_type, self.target, use_autotvm=False)
 
             op_pattern = 
op.get_attr("TOpPattern") if op_pattern >= _op.OpPattern.COMM_REDUCE: assert self.master_op is None or self.master_op_pattern < _op.OpPattern.COMM_REDUCE, \ @@ -237,7 +274,7 @@ def visit_call(self, call): self.master_op = op self.master_attrs = call.attrs self.master_op_pattern = op_pattern - self.master_implement = best_imp + self.master_implement = best_impl if len(outputs) > 1: assert isinstance(call.checked_type, _ty.TupleType) assert len(call.checked_type.fields) == len(outputs) @@ -269,30 +306,6 @@ def visit_tuple_getitem(self, op): assert op.index < tup.size() return [tup[op.index]] - def get_best_implement_by_autotvm(self, op_spec, attrs, inputs, ret_type): - min_cost = None - best_imp = None - for imp in op_spec.implements: - outs = imp.compute(attrs, inputs, ret_type) - if 'workload' not in outs[0].op.attrs: - continue - workload = _task.args_to_workload2(outs[0].op.attrs['workload']) - cfg = _task.DispatchContext.current.query(self.target, workload) - if cfg.cost is None: - # This is fallback config - continue - if min_cost is None or min_cost > cfg.cost: - min_cost = cfg.cost - best_imp = imp - return best_imp - - def get_best_implement_by_plevel(self, op_spec, attrs, inputs, ret_type): - best_imp = None - for imp in op_spec.implements: - if best_imp is None or int(imp.plevel) > int(best_imp.plevel): - best_imp = imp - return best_imp - @register_func("relay.backend.create_schedule") def create_schedule(src_func, target): diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 9c82d105cd918..f0586266cc70a 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -183,10 +183,9 @@ def _find_conv2d_op(op): reg.register_strategy("nn.conv2d", strategy.conv2d_strategy) @reg.register_alter_op_layout("nn.conv2d") -def alter_op_layout_conv2d(attrs, inputs, tinfos): +def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type): """Alternate the layout of conv2d""" - from ... import op - return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) + return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) @reg.register_legalize("nn.conv2d") def legalize_conv2d(attrs, inputs, types): diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 56467f8fea6b7..1ed25f0346c69 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -347,16 +347,6 @@ def _build(lowered_funcs): _schedule_injective = None _schedule_reduce = None -# def register_injective_schedule(target, schedule): -# def wrap_schedule(_, outs): -# return schedule(outs) -# _injective_schedule_map.append([target, wrap_schedule]) -# -# def register_reduce_schedule(target, schedule): -# def wrap_schedule(_, outs): -# return schedule(outs) -# _reduce_schedule_map.append([target, wrap_schedule]) - __DEBUG_COUNTER__ = 0 def debug(expr, debug_func=None): diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4d0330f4f1bf8..d715f634ad5fc 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name,unused-argument from __future__ import absolute_import +import re import topi from topi.util import get_const_int, get_const_float, get_const_tuple, get_float_tuple from .. 
import op as _op @@ -30,19 +31,19 @@ def wrapper(attrs, outs, target): return topi_schedule(outs) return wrapper -def get_conv2d_out_depth(kernel, kernel_layout): - weight_shape = get_const_tuple(kernel.shape) - # NHWC layout - if kernel_layout.startswith("HW"): - return weight_shape[2] * weight_shape[3] - # NCHW layout. - # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout - if len(weight_shape) == 4: - return weight_shape[0] * weight_shape[1] +def get_conv2d_data_channels(data, data_layout): + data_shape = get_const_tuple(data.shape) + if len(data_shape) == 4: + idx = data_layout.find("C") + assert idx >= 0, "Invalid conv2d data layout {}".format(data_layout) + return data_shape[idx] + elif re.match("NCHW\d*c", data_layout): + # NCHW[8]c + return data_shape[1] * data_shape[4] else: - assert len(weight_shape) == 5 - C, M, _, _, VC = weight_shape - return C * VC * M + raise ValueError("Unknown conv2d data layout {}".format(data_layout)) + return 0 + @generic_func def schedule_injective(attrs, outs, target): @@ -154,30 +155,35 @@ def conv2d_strategy(attrs, inputs, out_type, target): if groups == 1: if layout == "NCHW" or layout == "NCHW4c": + assert kernel_layout == "OIHW" strategy.add_implement( wrap_compute_conv2d(topi.nn.conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_conv2d_nchw)) elif layout == "NHWC": + assert kernel_layout == "HWIO" strategy.add_implement( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc)) elif layout == "HWCN": + assert kernel_layout == "HWIO" strategy.add_implement( wrap_compute_conv2d(topi.nn.conv2d_hwcn), wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn)) else: raise RuntimeError("Unsupported conv2d layout {}".format(layout)) else: - if layout == "NCHW" and get_conv2d_out_depth(inputs[1], kernel_layout) == groups: + is_depthwise = get_conv2d_data_channels(inputs[0], layout) == groups + if layout == "NCHW" and is_depthwise: + assert kernel_layout == "OIHW" strategy.add_implement( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw)) - elif layout == "NHWC" and kernel_layout == "HWOI" \ - and get_conv2d_out_depth(inputs[1], kernel_layout) == groups: + elif layout == "NHWC" and kernel_layout == "HWOI" and is_depthwise: strategy.add_implement( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc)) elif layout in ['NCHW', 'NCHW4c']: + assert kernel_layout == "OIHW" strategy.add_implement( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw)) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 4bc82b71bd4e7..72ba586ed9922 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -21,7 +21,6 @@ import logging import topi -from topi.util import get_const_int, get_const_float, get_const_tuple, get_float_tuple from .generic import * from .. 
import op as _op from ....schedule import SpecializedCondition @@ -78,6 +77,7 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): if groups == 1: if layout == "NCHW": + assert kernel_layout == "OIHW" if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype): strategy.add_implement( wrap_compute_conv2d(topi.x86.conv2d_nchw_int8), @@ -87,11 +87,13 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw)) elif layout == "NHWC": + assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implement( wrap_compute_conv2d(topi.nn.conv2d_nhwc), wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc)) elif layout == "HWCN": + assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implement( wrap_compute_conv2d(topi.nn.conv2d_hwcn), @@ -99,17 +101,27 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for x86 target".format(layout)) else: - if layout == "NCHW" and get_conv2d_out_depth(inputs[1], kernel_layout) == groups: - strategy.add_implement( - wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw)) - elif layout == "NHWC" and kernel_layout == "HWOI" \ - and get_conv2d_out_depth(inputs[1], kernel_layout) == groups: + is_depthwise = get_conv2d_data_channels(inputs[0], layout) == groups + if layout == "NCHW" and is_depthwise: + assert kernel_layout == "OIHW" + channel_multiplier = get_const_tuple(inputs[1].shape)[1] + if channel_multiplier == 1: + strategy.add_implement( + wrap_compute_conv2d(topi.x86.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_nchw)) + else: + logger.warning("For x86 target, depthwise_conv2d with channel " + "multiplier greater than 1 is not optimized") + strategy.add_implement( + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw)) + elif layout == "NHWC" and kernel_layout == "HWOI" and is_depthwise: logger.warning("For x86 target, NCHW layout is recommended for depthwise conv2d.") strategy.add_implement( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc)) elif layout in ['NCHW', 'NCHW4c']: + assert kernel_layout == "OIHW" logger.warning("Group conv2d is not optimized for x86 target.") strategy.add_implement( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), @@ -179,7 +191,7 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - m, k = inputs[0].shape + _, k = inputs[0].shape strategy.add_implement(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), 10) diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index b027e5edb571b..6629ea8cdceef 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -83,7 +83,10 @@ class AlterTransformMemorizer : public TransformMemorizer { auto ttype = expr->type_as(); tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); } - Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos); + // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes. 
+ // Probably we need to disable the AlterOpLayout when compiling dynamic models. + Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos, + ref_call->checked_type()); if (altered_value.defined()) { new_e = altered_value; modified = true; diff --git a/topi/python/topi/generic/injective.py b/topi/python/topi/generic/injective.py index f60fb9c36424f..de8da4ea9bc3b 100644 --- a/topi/python/topi/generic/injective.py +++ b/topi/python/topi/generic/injective.py @@ -35,7 +35,7 @@ def schedule_injective_from_existing(sch, out): sch: Schedule The updated schedule. """ - sch[out].fuse(sch[out].op.axis) + sch[out].fuse(*sch[out].op.axis) return sch def schedule_injective(outs): diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index b1976383dbd01..150ada9afba9e 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -95,7 +95,7 @@ def conv2d_legalize(attrs, inputs, types): @tvm.target.generic_func -def conv2d_alter_layout(attrs, inputs, tinfos, F): +def conv2d_alter_layout(attrs, inputs, tinfos, out_type): """Change Conv2D layout. Parameters diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 0a3d970b42184..82dadca7dceb8 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -57,7 +57,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth @conv2d_infer_layout.register("cpu") def _conv2d_infer_layout(workload, cfg): - _, data, kernel, strides, padding, dilation, layout, dtype = workload + _, data, kernel, strides, padding, dilation, layout, _, dtype = workload batch_size, in_channel, in_height, in_width = data[:-1] out_channel, _, k_height, k_width = kernel[:-1] idxdiv = tvm.indexdiv diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 2654b51b80d43..0e39ac7fefe11 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -23,22 +23,143 @@ from tvm import relay from tvm import autotvm from .conv2d import _get_default_config -from .conv2d import conv2d_NCHWc from .conv2d_int8 import is_int8_hw_support, _get_default_config_int8 -from ..util import get_const_tuple, get_shape -from ..nn import conv2d_legalize -from ..nn import conv2d_alter_layout -#from ..nn.conv2d import conv2d, conv2d_NCHWc_int8, conv2d_alter_layout -#from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw +from ..util import get_const_tuple +from ..nn import conv2d_legalize, conv2d_alter_layout from ..nn.util import get_pad_tuple logger = logging.getLogger('topi') @conv2d_alter_layout.register("cpu") -def _alter_conv2d_layout(attrs, inputs, tinfo, F): +def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): # Parse the attributes. - # TODO(@icemelon9): fix this + target = tvm.target.current_target(allow_none=False) + dispatch_ctx = autotvm.task.DispatchContext.current + if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest): + cfg = dispatch_ctx.query(target, None) + workload = cfg.workload + else: + _, outs = relay.backend.compile_engine.select_implement( + relay.op.nn.conv2d, attrs, tinfos, out_type, target) + workload = autotvm.task.get_workload(outs) + if workload is None: + # The best implementation is not an AutoTVM template, + # we then assume it's not necessary to alter this op. 
+ return None + cfg = dispatch_ctx.query(target, workload) + + topi_tmpl = workload[0] + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + data_layout = attrs["data_layout"] + kernel_layout = attrs["kernel_layout"] + data_tensor, kernel_tensor = tinfos + data_dtype = data_tensor.dtype + kernel_dtype = kernel_tensor.dtype + out_dtype = out_type.dtype + new_attrs = {k : attrs[k] for k in attrs.keys()} + + if topi_tmpl == "conv2d_NCHWc.x86": + # we only convert conv2d_NCHW to conv2d_NCHWc for x86 + assert data_layout == "NCHW" and kernel_layout == "OIHW" + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, False, data_layout) + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config + new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + new_workload = autotvm.task.args_to_workload2( + [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], + new_attrs["out_layout"], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) + elif topi_tmpl == "conv2d_NCHWc_int8.x86": + # TODO(@icemelon9, @anijain2305): Need to support data layout NHWC with kernel layout HWIO + assert data_layout == "NCHW" and kernel_layout == "OIHW" + if cfg.is_fallback: + _get_default_config_int8(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, False, data_layout) + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + n_elems = 4 + + # convert kernel data layout from 4D to 7D + data_expr, kernel_expr = inputs + kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0)) + kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel//oc_bn, oc_bn)) + kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) + kernel_OHWoIi = relay.reshape(kernel_OHWoI, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn)) + kernel_OHWoIie = relay.reshape(kernel_OHWoIi, (out_channel//oc_bn, kh, kw, oc_bn, + in_channel//ic_bn, ic_bn//n_elems, n_elems)) + kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config. 
+ new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = tvm.placeholder((out_channel // oc_bn, + in_channel // ic_bn, + kh, + kw, + ic_bn // n_elems, + oc_bn, + n_elems), dtype=kernel_dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], + new_attrs['out_layout'], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + + return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs) + elif topi_tmpl == "depthwise_conv2d_NCHWc.x86": + assert data_layout == "NCHW" and kernel_layout == "OIHW" + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, True, data_layout) + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + assert channel_multiplier == 1 + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config. + new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = tvm.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], + new_attrs['out_layout'], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs) return None + + """ groups = attrs.get_int("groups") padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") @@ -91,7 +212,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): workload = autotvm.task.args_to_workload( [data_tensor, kernel_tensor, strides, padding, dilation, data_layout, out_dtype], conv2d) - print('alter_conv2d_layout', workload) + cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: if is_int8_hw_support(data_dtype, kernel_dtype): @@ -172,6 +293,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): dispatch_ctx.update(target, new_workload, cfg) return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) + """ @conv2d_legalize.register("cpu") diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 60b5afefe03bd..e586975b5d8b5 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -104,13 +104,15 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, if len(data.shape) == 5: batch, in_channel_chunk, in_height, in_width, in_channel_block = get_const_tuple(data.shape) - out_channel_chunk, _, filter_height, filter_width, __, out_channel_block \ + out_channel_chunk, cm_chunk, filter_height, filter_width, cm_block, out_channel_block \ = get_const_tuple(kernel.shape) in_channel = in_channel_chunk * in_channel_block out_channel = out_channel_chunk * out_channel_block + channel_multiplier = cm_chunk * cm_block else: batch, in_channel, in_height, in_width = get_const_tuple(data.shape) - out_channel, _, filter_height, filter_width = get_const_tuple(kernel.shape) + out_channel, channel_multiplier, filter_height, filter_width = get_const_tuple(kernel.shape) + assert 
channel_multiplier == 1 strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) HSTR, WSTR = strides @@ -119,8 +121,6 @@ def depthwise_conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) assert (dh, dw) == (1, 1), "Does not support dilation" - channel_multiplier = out_channel // in_channel - out_height = (in_height - filter_height + pad_top + pad_down) // HSTR + 1 out_width = (in_width - filter_width + pad_left + pad_right) // WSTR + 1 diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index a9a774d589873..eec7b79a89ade 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -136,20 +136,20 @@ def tune_kernels(tasks, prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) # converting conv2d tasks to conv2d_NCHWc tasks - task_name = tsk.workload[0] - if task_name.startswith('conv2d'): - #func_create = 'topi_x86_conv2d_NCHWc' - task_name = "conv2d_NCHWc.x86" - else: - continue + # op_name = tsk.workload[0] + # if op_name == 'conv2d': + # func_create = 'topi_x86_conv2d_NCHWc' # elif op_name == 'depthwise_conv2d_nchw': # func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw' # else: # raise ValueError("Tuning {} is not supported on x86".format(op_name)) - - task = autotvm.task.create(task_name, args=tsk.args, - target=target, template_key='direct') - task.workload = tsk.workload + # + # task = autotvm.task.create(func_create, args=tsk.args, + # target=target, template_key='direct') + # task.workload = tsk.workload + task = tsk + if 'dense' in task.name: + continue # create tuner if tuner == 'xgb' or tuner == 'xgb-rank':
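
The core of the compile_engine change above is the rule that select_implement applies when several registered implementations match a call: prefer the implementation whose tuned AutoTVM record has the lowest cost, and fall back to the highest plevel when no tuning record exists (or when use_autotvm=False, as in the dynamic-shape path). The following is a minimal, self-contained sketch of that rule only; MockImpl, MockConfig, and query_cfg are hypothetical stand-ins for illustration, not TVM APIs.

    # Sketch of the implementation-selection rule in select_implement (illustration only).
    from collections import namedtuple

    MockImpl = namedtuple("MockImpl", ["name", "plevel"])
    MockConfig = namedtuple("MockConfig", ["cost"])  # cost is None for fallback configs

    def select_impl_sketch(impls, query_cfg, use_autotvm=True):
        # Highest plevel wins whenever AutoTVM cannot decide.
        best_plevel_impl = max(impls, key=lambda impl: impl.plevel)
        if not use_autotvm:
            return best_plevel_impl
        # Otherwise pick the implementation whose tuned config has the lowest cost.
        best_impl, best_cost = None, None
        for impl in impls:
            cfg = query_cfg(impl)          # stands in for DispatchContext.query(target, workload)
            if cfg is None or cfg.cost is None:
                continue                   # fallback config, i.e. no tuning log entry
            if best_cost is None or cfg.cost < best_cost:
                best_impl, best_cost = impl, cfg.cost
        return best_impl if best_impl is not None else best_plevel_impl

    if __name__ == "__main__":
        impls = [MockImpl("conv2d_nchw.x86", 10), MockImpl("conv2d_NCHWc.x86", 15)]
        tuned = {"conv2d_NCHWc.x86": MockConfig(cost=0.12)}
        print(select_impl_sketch(impls, lambda impl: tuned.get(impl.name)).name)
        # -> conv2d_NCHWc.x86, because it is the only implementation with a tuning record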
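The "conv2d_NCHWc.x86" branch of _alter_conv2d_layout rewrites the kernel layout from OIHW to OIHW{ic_bn}i{oc_bn}o, as noted by the comment "(oc, ic, h, w) -> (OC, IC, h, w, ic, oc)". The short numpy sketch below only illustrates what that layout string means for the weight data; in the actual pass the repacking is done by Relay's layout conversion, and repack_oihw is a hypothetical helper, not part of this patch.

    # Illustration only: repack an OIHW kernel into OIHW{ic_bn}i{oc_bn}o blocked layout.
    import numpy as np

    def repack_oihw(kernel, ic_bn, oc_bn):
        """(O, I, H, W) -> (O//oc_bn, I//ic_bn, H, W, ic_bn, oc_bn)."""
        O, I, H, W = kernel.shape
        assert O % oc_bn == 0 and I % ic_bn == 0
        # Split O into (O//oc_bn, oc_bn) and I into (I//ic_bn, ic_bn), then move the
        # block axes innermost in (ic, oc) order to match the OIHW..i..o layout string.
        return (kernel.reshape(O // oc_bn, oc_bn, I // ic_bn, ic_bn, H, W)
                      .transpose(0, 2, 4, 5, 3, 1))

    if __name__ == "__main__":
        w = np.random.randn(64, 32, 3, 3).astype("float32")
        print(repack_oihw(w, ic_bn=8, oc_bn=16).shape)  # (4, 4, 3, 3, 8, 16)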