diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
index 9b6e70e056ae..96c2ed2bd7e0 100644
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -50,7 +50,7 @@ def extract_from_program(func, params, ops, target, target_host=None):
     # relay op -> topi compute
     OP2TOPI = {
         tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
-                                 topi.nn.group_conv2d_nchw],
+                                 topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
         tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
         tvm.relay.op.nn.dense: [topi.nn.dense],
         tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index 66fa07510237..c35cc20343e9 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -67,6 +67,7 @@ def __init__(self):
             topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
             topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
             topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
+            topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
             topi.nn.dense: "topi_nn_dense",
             topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
@@ -80,6 +81,7 @@ def __init__(self):
                                             topi.generic.schedule_depthwise_conv2d_nhwc],
             topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
             topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
+            topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
             topi.nn.dense: [topi.generic.schedule_dense],
             topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
@@ -108,7 +110,6 @@ def _tracing_topi_compute(*args, **kwargs):
                 key = (self.topi_to_task[compute_func], serialize_args(args))
                 if key not in self.task_collection:
                     self.task_collection.append(key)
-                return compute_func.fdefault(*args)
 
             _local_scope(topi_compute)
 
@@ -205,6 +206,15 @@ def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
             s = topi.generic.schedule_deformable_conv2d_nchw([C])
             return s, [A, Offset, W, C]
 
+        @register("topi_nn_conv2d_NCHWc")
+        def _topi_nn_conv2d_NCHWc(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.conv2d_NCHWc(*args, **kwargs)
+            s = topi.generic.schedule_conv2d_NCHWc([C])
+            return s, [A, W, C]
+
     def reset(self, wanted_topi_funcs):
         """Reset task collections
 
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index a67f608d26dc..b7e222f74cc9 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -329,7 +329,68 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     """
     # search platform specific declaration first
    # default declaration
-    raise ValueError("missing register for topi.nn.conv2d_NCHWc")
+    # layout and out_layout are not used here,
+    # we keep them for debug convenience when dumping autotvm workloads
+    HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride)
+    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+    assert (dh, dw) == (1, 1), "Does not support dilation"
+
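+    # data in NCHW[x]c layout is 5-D: (batch, in_channel_chunk, height, width, in_channel_block)
+    # the fp32 kernel is 6-D; the int8 kernel carries an extra innermost axis of 4
+    # packed input-channel values (see the unpacking below)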
+    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
+    in_channel = ic_chunk * ic_bn
+    if data.dtype == 'uint8':
+        oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
+    else:
+        oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+    num_filter = oc_chunk * oc_bn
+
+    # dilation is fixed to 1 above, so the dilated kernel extent equals the kernel extent
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding,
+                                                           (kernel_height, kernel_width))
+    HPAD = pad_top + pad_down
+    WPAD = pad_left + pad_right
+
+    # output shape
+    out_height = (ih + HPAD - kernel_height) // HSTR + 1
+    out_width = (iw + WPAD - kernel_width) // WSTR + 1
+    oshape = (n, oc_chunk, out_height, out_width, oc_bn)
+
+    # DOPAD
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, pad_top, pad_left, 0),
+                       (0, 0, pad_down, pad_right, 0), name="data_pad")
+    else:
+        data_pad = data
+
+    ic = tvm.reduce_axis((0, in_channel), name='ic')
+    kh = tvm.reduce_axis((0, kernel_height), name='kh')
+    kw = tvm.reduce_axis((0, kernel_width), name='kw')
+
+    if data.dtype == 'uint8':
+        assert out_dtype == "int32", \
+            "INT8 convolution requires input dtype = uint8 and output dtype = int32"
+        # the int8 kernel packs groups of 4 input-channel values so that one
+        # reduction step is a 4-element dot product; ic_bn must be a multiple of 4
+        n_elems = 4
+        assert ic_bn % n_elems == 0
+
+        ic_outer = tvm.reduce_axis((0, in_channel//ic_bn), name='ic_outer')
+        ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner')
+        ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
+        return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                           tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                                            ic_f_inner * n_elems + ic_s_inner]
+                                   .astype(out_dtype) *
+                                   kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner,
+                                          oc_block, ic_s_inner].astype(out_dtype),
+                                   axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
+                           name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+    # non-int8 (fp32) path
+    return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
+                       tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw,
+                                        ic%ic_bn].astype(out_dtype) *
+                               kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block],
+                               axis=[ic, kh, kw]),
+                       name='conv2d_NCHWc', tag="conv2d_NCHWc")
+
 
 
 def conv2d_winograd_weight_transform(kernel, tile_size):
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index f7bfc3520b36..e3f1ac5d02e4 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -2,6 +2,7 @@
 """Conv2D schedule on x86"""
 
 import logging
+import re
 
 import tvm
 from tvm import autotvm
@@ -41,9 +42,22 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
     """Create schedule configuration from input arguments"""
     dshape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
+    pat = re.compile(r'NCHW.+(\d+)c')
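+    # packed layouts such as "NCHW16c" match this pattern; plain "NCHW" does not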
     if layout == 'NCHW':
         n, ic, h, w = dshape
         oc, _, kh, kw = kshape
+    elif pat.match(layout) is not None:
+        n, ic_chunk, h, w, ic_bn = dshape
+        if data.dtype == 'uint8':
+            oc_chunk, k_ic, kh, kw, k_ic_f, oc_bn, k_ic_s = kshape
+            ic = ic_chunk*ic_bn
+            assert ic == k_ic*k_ic_f*k_ic_s
+        else:
+            oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
+            assert ic_chunk == k_ic_chunk
+            assert ic_bn == k_ic_bn
+            ic = ic_chunk*ic_bn
+        oc = oc_chunk*oc_bn
     else:
         raise ValueError("Not support this layout {} with "
                          "schedule template.".format(layout))
@@ -258,7 +272,14 @@ def traverse(op):
 
 @autotvm.task.register("topi_x86_conv2d_NCHWc")
 def _topi_nn_conv2d_NCHWc(*args, **kwargs):
     assert not kwargs, "Do not support kwargs in template function call"
-    data, kernel, strides, padding, dilation, origin_layout, dtype = deserialize_args(args)
+    args = deserialize_args(args)
+
+    # serialized workloads may come with or without an explicit out_layout
+    if len(args) == 7:
+        data, kernel, strides, padding, dilation, origin_layout, dtype = args
+    else:
+        assert len(args) == 8
+        data, kernel, strides, padding, dilation, origin_layout, out_layout, dtype = args
+
     raw_data_shape = get_const_tuple(data.shape)
     raw_kernel_shape = get_const_tuple(kernel.shape)