From 7702e488ca9e2f75040e03e55d131b5dae0181a5 Mon Sep 17 00:00:00 2001 From: Yida Wang Date: Tue, 25 Feb 2020 13:14:58 -0800 Subject: [PATCH] [Fix] remove unnecessary spliting in the cached chunk (#4935) * remove unnecessary spliting in the cached chunk * remove unnecessary spliting in the cached chunk --- topi/python/topi/x86/depthwise_conv2d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/topi/python/topi/x86/depthwise_conv2d.py b/topi/python/topi/x86/depthwise_conv2d.py index 2aa5e748e5c7..70b30fea8c51 100644 --- a/topi/python/topi/x86/depthwise_conv2d.py +++ b/topi/python/topi/x86/depthwise_conv2d.py @@ -223,12 +223,12 @@ def _schedule_depthwise_conv2d_NCHWc_impl(s, cfg, data_vec, kernel_vec, conv_out s[C].parallel(parallel_axis) s[CC].compute_at(s[C], ow_chunk) + # the ow axis in the cached block CC is the ow_block in C _, ic_chunk, oh, ow, ic_block = s[CC].op.axis kh, kw = s[CC].op.reduce_axis - ow_chunk, ow_block = s[CC].split(ow, factor=tile_ow) - s[CC].reorder(ic_chunk, oh, kh, kw, ow_block, ic_block) + s[CC].reorder(ic_chunk, oh, kh, kw, ow, ic_block) s[CC].vectorize(ic_block) - s[CC].unroll(ow_block) + s[CC].unroll(ow) if C != O: out_ndim = len(s[O].op.axis)