From 06b2ede4f44c10d6658adb291f945cc4ab08bda4 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 30 Mar 2020 19:02:07 -0700 Subject: [PATCH] [TOPI] Setting workload correctly for Depthwise conv ARM. (#5182) --- topi/python/topi/arm_cpu/conv2d_alter_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py index 3a22611ed128..553239b6c426 100644 --- a/topi/python/topi/arm_cpu/conv2d_alter_op.py +++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py @@ -154,14 +154,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if topi_tmpl == "depthwise_conv2d_nchw_spatial_pack.arm_cpu": assert data_layout == "NCHW" and kernel_layout == "OIHW" N, CI, H, W = get_const_tuple(data.shape) - CO, _, KH, KW = get_const_tuple(kernel.shape) + CO, M, KH, KW = get_const_tuple(kernel.shape) VC = cfg['tile_co'].size[-1] new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1]) # Store the same config for the altered operator (workload) new_data = data - new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype) + new_kernel = te.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype) new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, out_dtype], "depthwise_conv2d_nchw_spatial_pack.arm_cpu")