From 657cffd208f41f5afed4f95433affe269157ab72 Mon Sep 17 00:00:00 2001 From: Yanming Wang Date: Fri, 24 Jul 2020 16:00:09 -0700 Subject: [PATCH] [AutoTVM][BugFix] Fix autotvm on the conv2d_nchw_winograd.mali operator (#6130) * [AutoTVM] Fix conv2d_nchw_winograd.mali * Fix pylint error Co-authored-by: Yanming Wang --- python/tvm/autotvm/task/task.py | 10 +++++++--- topi/python/topi/mali/conv2d.py | 3 +-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index b7cd6f2b04ede..3942599e2cb1f 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -216,19 +216,23 @@ def __call__(self, *args, **kwargs): def _default_func(self, *args, **kwargs): assert callable(self.fcompute) and callable(self.fschedule) out = self.fcompute(*args, **kwargs) - arg_bufs = [out] + self.get_inputs(out) + arg_bufs = [out] + self._get_inputs(out) s = self.fschedule([out]) return s, arg_bufs - def get_inputs(self, out): + @staticmethod + def _get_inputs(out): inputs = [] queue = [out] + hash_set = set() while queue: t = queue.pop(0) if isinstance(t.op, tensor.PlaceholderOp): inputs.append(t) else: - queue.extend(t.op.input_tensors) + input_tensors = [t for t in t.op.input_tensors if t not in hash_set] + queue.extend(input_tensors) + hash_set.update(input_tensors) return inputs def _register_task_compute(name, func=None): diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index ed19326749643..f2b26ee09faf6 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -276,8 +276,7 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d') if autotvm.GLOBAL_SCOPE.in_tuning: - VC = cfg['tile_k'].size[-1] - kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC) + kvshape = (alpha, alpha, CO // bna, CI, bna) U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: # transform kernel