diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index ce33d3ed3c0c..8b3ba35e92ab 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -182,12 +182,15 @@ def _topi_nn_conv2d(*args, **kwargs):
     args = deserialize_args(args)
     A, W = args[:2]
     layout = args[-2]
-    assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
     C = topi.nn.conv2d(*args, **kwargs)
     if layout == 'NCHW':
         s = topi.generic.schedule_conv2d_nchw([C])
-    else:
+    elif layout == 'HWCN':
         s = topi.generic.schedule_conv2d_hwcn([C])
+    elif layout == 'NHWC':
+        s = topi.generic.schedule_conv2d_nhwc([C])
+    else:
+        raise ValueError("Unsupported layout {}".format(layout))
     return s, [A, W, C]
 
 @register("topi_nn_depthwise_conv2d_nchw")
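With the layout dispatch above, an NHWC conv2d can be instantiated as a standalone AutoTVM task through this same template. A minimal sketch, assuming the template is registered as 'topi_nn_conv2d' (matching the function name, per the @register convention in this file) and using illustrative shapes; serialize_args is the helper defined alongside deserialize_args in topi_integration.py:

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import serialize_args

# Illustrative 3x3 NHWC workload. The argument order mirrors topi.nn.conv2d:
# (data, kernel, strides, padding, dilation, layout, out_dtype),
# so layout lands at args[-2] as the template expects.
data = tvm.placeholder((1, 56, 56, 64), name='data')
kernel = tvm.placeholder((3, 3, 64, 64), name='kernel')
args = serialize_args([data, kernel, 1, 1, 1, 'NHWC', 'float32'])

task = autotvm.task.create('topi_nn_conv2d', args=args,
                           target='llvm -device=arm_cpu',
                           template_key='direct')
print(task.config_space)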
diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index 6e95de579c19..673307a62925 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -24,7 +24,8 @@
 from tvm import autotvm
 import tvm.contrib.nnpack
 
-from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform, \
+from ..generic import schedule_conv2d_nchw, schedule_conv2d_nhwc, \
+    schedule_conv2d_winograd_without_weight_transform, \
     schedule_conv2d_winograd_nnpack_without_weight_transform
 from ..util import traverse_inline, get_const_tuple
 from ..nn import dilate, pad, conv2d, conv2d_alter_layout, \
@@ -34,7 +35,9 @@
 from ..nn.util import get_const_int, get_pad_tuple
 from ..nn.winograd_util import winograd_transform_matrices
 from .conv2d_spatial_pack import conv2d_spatial_pack_nchw, \
-    schedule_conv2d_spatial_pack_nchw
+    conv2d_spatial_pack_nhwc, \
+    schedule_conv2d_spatial_pack_nchw, \
+    schedule_conv2d_spatial_pack_nhwc
 
 logger = logging.getLogger('topi')
 
@@ -78,6 +81,9 @@ def conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     if layout == 'NCHW':
         return conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding,
                                         dilation, out_dtype, num_tile=2)
+    elif layout == 'NHWC':
+        return conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding,
+                                        dilation, out_dtype)
     else:
         raise ValueError("Unsupported layout {}".format(layout))
 
@@ -136,6 +142,34 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
     return s
 
+@autotvm.register_topi_schedule(schedule_conv2d_nhwc, 'arm_cpu', ['direct'])
+def schedule_conv2d_nhwc_arm_cpu(cfg, outs):
+    """TOPI schedule callback for conv2d
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv2d.
+    """
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if 'spatial_conv_output_NHWC' in op.tag:
+            schedule_conv2d_spatial_pack_nhwc(cfg, s, op, outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
 @autotvm.register_topi_compute(conv2d, 'arm_cpu', ['winograd'])
 def conv2d_arm_cpu_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
     """ TOPI compute callback. Use winograd template """
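The compute above is registered for layout 'NHWC' under the 'direct' template key, and the schedule attaches through the 'spatial_conv_output_NHWC' tag that traverse_inline matches. A rough sketch of exercising the pair directly through TOPI, with illustrative shapes; when no tuning records are in scope, AutoTVM falls back to a default configuration and logs a warning:

import tvm
import topi

data = tvm.placeholder((1, 56, 56, 64), name='data')
kernel = tvm.placeholder((3, 3, 64, 64), name='kernel')

# Under an arm_cpu target, topi.nn.conv2d dispatches to conv2d_arm_cpu,
# which now accepts layout='NHWC'; the generic schedule resolves to
# schedule_conv2d_nhwc_arm_cpu via the output tag.
with tvm.target.create('llvm -device=arm_cpu'):
    out = topi.nn.conv2d(data, kernel, strides=1, padding=1, dilation=1,
                         layout='NHWC', out_dtype='float32')
    s = topi.generic.schedule_conv2d_nhwc([out])
    func = tvm.build(s, [data, kernel, out])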
diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
index b566c98a4ec5..350a0227ef48 100644
--- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
+++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py
@@ -196,3 +196,163 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec,
     s[kernel_vec].parallel(co)
 
     return s
+
+def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Spatial pack compute for Conv2d NHWC"""
+    out_dtype = out_dtype or data.dtype
+
+    N, IH, IW, IC = get_const_tuple(data.shape)
+    assert len(kernel.shape) == 4, "AlterOpLayout not enabled for NHWC yet"
+    KH, KW, _, OC = get_const_tuple(kernel.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = \
+        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0])
+
+    # ==================== define configuration space ====================
+    n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW)
+    ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW)
+
+    oco, oci = cfg.define_split('tile_co', oc, num_outputs=2)
+    oho, ohi = cfg.define_split('tile_oh', oh, num_outputs=2)
+    owo, owi = cfg.define_split('tile_ow', ow, num_outputs=2)
+
+    cfg.define_reorder('reorder_conv',
+                       [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+                       policy='candidate', candidate=[
+                           [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci],
+                           [n, oho, owo, oco, ohi, kh, kw, ic, owi, oci],
+                           [n, oho, owo, oco, ohi, kh, kw, owi, ic, oci],
+                           [n, oho, owo, ohi, oco, kh, kw, owi, ic, oci]])
+
+    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
+    cfg.define_annotate("ann_spatial", [ohi, owi, oci], policy='try_unroll_vec')
+    # ====================================================================
+
+    OCI = cfg['tile_co'].size[-1]
+    OHI = cfg['tile_oh'].size[-1]
+    OWI = cfg['tile_ow'].size[-1]
+    OCO = OC // OCI
+    OHO = OH // OHI
+    OWO = OW // OWI
+
+    kvshape = (OCO, KH, KW, IC, OCI)
+    ovshape = (N, OHO, OWO, OCO, OHI, OWI, OCI)
+    oshape = (N, OH, OW, OC)
+
+    if dilation_h != 1 or dilation_w != 1:
+        # undilate input data
+        dvshape = (N, OHO, OWO, KH, KW, IC, OHI, OWI)
+        data_vec = tvm.compute(dvshape, lambda n, oho, owo, kh, kw, ic, ohi, owi:
+                               data_pad[n][(oho*OHI+ohi)*HSTR+kh*dilation_h]
+                               [(owo*OWI+owi)*WSTR+kw*dilation_w][ic],
+                               name='data_vec_undilated')
+    else:
+        dvshape = (N, OHO, OWO, KH + (OHI-1)*HSTR, KW + (OWI-1)*WSTR, IC)
+        data_vec = tvm.compute(dvshape, lambda n, oho, owo, ohi, owi, ic:
+                               data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic],
+                               name='data_vec')
+    kernel_vec = tvm.compute(kvshape, lambda oco, kh, kw, ic, oci: \
+                             kernel[kh][kw][ic][oco*OCI+oci],
+                             name='kernel_vec')
+
+    ic = tvm.reduce_axis((0, IC), name='ic')
+    kh = tvm.reduce_axis((0, KH), name='kh')
+    kw = tvm.reduce_axis((0, KW), name='kw')
+
+    if dilation_h != 1 or dilation_w != 1:
+        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+            tvm.sum(data_vec[n, oho, owo, kh, kw, ic, ohi, owi].astype(out_dtype) *
+                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                    axis=[ic, kh, kw]), name='conv')
+    else:
+        conv = tvm.compute(ovshape, lambda n, oho, owo, oco, ohi, owi, oci: \
+            tvm.sum(data_vec[n, oho, owo, ohi*HSTR+kh, owi*WSTR+kw, ic].astype(out_dtype) *
+                    kernel_vec[oco, kh, kw, ic, oci].astype(out_dtype),
+                    axis=[ic, kh, kw]), name='conv')
+
+    idiv = tvm.indexdiv
+    imod = tvm.indexmod
+    output = tvm.compute(oshape, lambda n, oho, owo, oc:
+                         conv[n][idiv(oho, OHI)][idiv(owo, OWI)][idiv(oc, OCI)]\
+                         [imod(oho, OHI)][imod(owo, OWI)][imod(oc, OCI)],
+                         name='output_unpack', tag='spatial_conv_output_NHWC')
+    return output
+
+def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output):
+    """Spatial Pack schedule for Conv2d NHWC"""
+    unpack = op.output(0)
+    conv = unpack.op.input_tensors[0]
+    data_vec = conv.op.input_tensors[0]
+    kernel_vec = conv.op.input_tensors[1]
+    data_pad = data_vec.op.input_tensors[0]
+    OHI = cfg['tile_oh'].size[-1]
+    OWI = cfg['tile_ow'].size[-1]
+    OCI = cfg['tile_co'].size[-1]
+
+    # schedule unpack/output
+    if output != unpack:
+        s[unpack].compute_inline()
+    n, oh, ow, oc = s[output].op.axis
+    oco, oci = cfg['tile_co'].apply(s, output, oc)
+    oho, ohi = cfg['tile_oh'].apply(s, output, oh)
+    owo, owi = cfg['tile_ow'].apply(s, output, ow)
+    s[output].reorder(n, oho, owo, oco, ohi, owi, oci)
+    cfg['ann_spatial'].apply(s, output, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+                             max_unroll=16, cfg=cfg)
+    cfg.define_knob('compat', [0, 1, 2])
+    if cfg['compat'].val < 2:
+        compat_axis = [owo, oco][cfg['compat'].val]  # pylint: disable=R1706
+        s[conv].compute_at(s[output], compat_axis)
+    paxis = s[output].fuse(n, oho)
+    s[output].parallel(paxis)
+
+    # schedule conv
+    n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis
+    ic, kh, kw = s[conv].op.reduce_axis
+    cfg['reorder_conv'].apply(s, conv, [n, oho, owo, oco, kh, kw, ohi, owi, ic, oci])
+    cfg['ann_reduce'].apply(s, conv, [kh, kw],
+                            axis_lens=[get_const_int(kh.dom.extent),
+                                       get_const_int(kw.dom.extent)],
+                            max_unroll=16,
+                            cfg=cfg)
+    cfg['ann_spatial'].apply(s, conv, [ohi, owi, oci], axis_lens=[OHI, OWI, OCI],
+                             max_unroll=16, cfg=cfg)
+    if cfg['compat'].val < 2:
+        compat_axis = [owo, oco][cfg['compat'].val]  # pylint: disable=R1706
+        s[kernel_vec].compute_at(s[conv], compat_axis)
+        s[data_vec].compute_at(s[conv], compat_axis)
+
+    # schedule kernel pack
+    oco, kh, kw, ic, oci = kernel_vec.op.axis
+    s[kernel_vec].vectorize(oci)
+    s[kernel_vec].unroll(ic)
+    if cfg['compat'].val == 2:
+        s[kernel_vec].parallel(oco)
+
+    # schedule data pack
+    if data_vec.op.name == 'data_vec_undilated':
+        n, oho, owo, kh, kw, ic, ohi, owi = s[data_vec].op.axis
+        s[data_vec].vectorize(owi)
+        s[data_vec].unroll(ohi)
+    else:
+        n, oho, owo, ohi, owi, ic = s[data_vec].op.axis
+        s[data_vec].vectorize(ic)
+        s[data_vec].unroll(owi)
+    if cfg['compat'].val == 2:
+        paxis = s[data_vec].fuse(n, oho)
+        s[data_vec].parallel(paxis)
+
+    return s
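End to end, the new template is tuned like any other AutoTVM task. A hedged sketch of a tuning loop; the trial count, runner settings, and the log file name 'conv2d_nhwc.log' are arbitrary choices for illustration, not part of this patch:

import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner
from tvm.autotvm.task.topi_integration import serialize_args

# Same illustrative workload as in the earlier sketch.
data = tvm.placeholder((1, 56, 56, 64), name='data')
kernel = tvm.placeholder((3, 3, 64, 64), name='kernel')
task = autotvm.task.create(
    'topi_nn_conv2d',
    args=serialize_args([data, kernel, 1, 1, 1, 'NHWC', 'float32']),
    target='llvm -device=arm_cpu', template_key='direct')

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=5, repeat=1))

tuner = XGBTuner(task)
tuner.tune(n_trial=200, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('conv2d_nhwc.log')])

# Apply the best record found; compilation inside this scope picks up
# the tuned config instead of the fallback.
with autotvm.apply_history_best('conv2d_nhwc.log'):
    pass  # build as in the earlier sketch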