From b444f708c9cef85b1395d55abaf31aad3fa7d9de Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 14 Jun 2022 12:05:10 -0700 Subject: [PATCH] [TOPI] Allow conv definition to have custom kernel layout --- python/tvm/relay/op/strategy/cuda.py | 9 +++ python/tvm/relay/op/strategy/generic.py | 4 + python/tvm/topi/nn/conv1d.py | 10 +-- python/tvm/topi/nn/conv2d.py | 99 +++++++++++++++---------- python/tvm/topi/nn/conv3d.py | 3 +- 5 files changed, 79 insertions(+), 46 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 072b958da213d..47b33722b115f 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -322,6 +322,15 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8), name="conv2d_NCHWc_int8.cuda", ) + elif is_auto_scheduler_enabled(): + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True + ), + naive_schedule, + name="conv2d.cuda", + plevel=15, + ) elif target.kind.name == "cuda" and "cudnn" not in target.libs: # No TVM native kernel applicable raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4ff7490b89ace..15bd35c809f8f 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -224,6 +224,7 @@ def schedule_bitpack(attrs, outs, target): def wrap_compute_conv2d( topi_compute, need_data_layout=False, + need_kernel_layout=False, need_out_layout=False, has_groups=False, need_auto_scheduler_layout=False, @@ -236,6 +237,7 @@ def _compute_conv2d(attrs, inputs, out_type): strides = get_const_tuple(attrs.strides) dilation = get_const_tuple(attrs.dilation) data_layout = attrs.get_str("data_layout") + kernel_layout = attrs.get_str("kernel_layout") out_layout = attrs.get_str("out_layout") out_dtype = attrs.out_dtype out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype @@ -244,6 +246,8 @@ def _compute_conv2d(attrs, inputs, out_type): args.append(attrs.groups) if need_data_layout: args.append(data_layout) + if need_kernel_layout: + args.append(kernel_layout) if need_out_layout: args.append(out_layout) args.append(out_dtype) diff --git a/python/tvm/topi/nn/conv1d.py b/python/tvm/topi/nn/conv1d.py index 0a1efa35655f9..560a342d5659f 100644 --- a/python/tvm/topi/nn/conv1d.py +++ b/python/tvm/topi/nn/conv1d.py @@ -47,17 +47,17 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", o out_dtype : str The output data type. If None then output is same type as input. """ - return conv(data, kernel, strides, padding, dilation, 1, layout, out_dtype) + return conv(data, kernel, strides, padding, dilation, 1, layout, "", out_dtype) def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None): """1D convolution in NWC layout. See :py:func:`conv` for details on parameters""" - return conv(data, kernel, strides, padding, dilation, 1, "NWC", out_dtype=out_dtype) + return conv(data, kernel, strides, padding, dilation, 1, "NWC", "", out_dtype=out_dtype) def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None): """1D convolution in NCW layout. 
See :py:func:`conv` for details on parameters""" - return conv(data, kernel, strides, padding, dilation, 1, "NCW", out_dtype=out_dtype) + return conv(data, kernel, strides, padding, dilation, 1, "NCW", "", out_dtype=out_dtype) def group_conv1d_nwc( @@ -89,7 +89,7 @@ def group_conv1d_nwc( out_dtype : str The output data type. If None then output is same type as input. """ - return conv(data, kernel, strides, padding, dilation, groups, "NWC", out_dtype=out_dtype) + return conv(data, kernel, strides, padding, dilation, groups, "NWC", "", out_dtype=out_dtype) def group_conv1d_ncw( @@ -121,4 +121,4 @@ def group_conv1d_ncw( out_dtype : str The output data type. If None then output is same type as input. """ - return conv(data, kernel, strides, padding, dilation, groups, "NCW", out_dtype=out_dtype) + return conv(data, kernel, strides, padding, dilation, groups, "NCW", "", out_dtype=out_dtype) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 5db752f6d54f0..32bf3f703f347 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -89,7 +89,7 @@ def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=N """ # search platform specific declaration first # default declaration - return conv(input, filter, strides, padding, dilation, 1, layout, out_dtype) + return conv(input, filter, strides, padding, dilation, 1, layout, "", out_dtype) @tvm.target.generic_func @@ -239,7 +239,7 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): Output : tvm.te.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", out_dtype=out_dtype) + return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", "", out_dtype=out_dtype) def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): @@ -269,7 +269,7 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None): output : tvm.te.Tensor 4-D with shape [out_height, out_width, out_channel, batch] """ - return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", out_dtype=out_dtype) + return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", "", out_dtype=out_dtype) def conv2d_nhwc( @@ -325,6 +325,7 @@ def conv2d_nhwc( dilation, 1, "NHWC", + "", out_dtype, auto_scheduler_rewritten_layout, meta_schedule_original_shape, @@ -708,7 +709,7 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp Output : tvm.te.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", out_dtype=out_dtype) + return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", "", out_dtype=out_dtype) def conv( @@ -718,7 +719,8 @@ def conv( padding: Union[int, Sequence[int]], dilation: Union[int, Sequence[int]], groups: int, - order: str, + data_layout: str, + kernel_layout: str = "", out_dtype: Union[str, None] = None, auto_scheduler_rewritten_layout: Optional[str] = None, meta_schedule_original_shape=None, @@ -731,11 +733,11 @@ def conv( Parameters ---------- inp : tvm.te.Tensor - N-D with shape [batch, in_channel, in_height, in_width, ...] ordered by `order` + N-D with shape [batch, in_channel, in_height, in_width, ...] in `data_layout` filt : tvm.te.Tensor - N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...] 
- for NCHW or [filter_height, filter_width, ..., in_channel // groups, num_filter] for NHWC + N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...] in + `kernel_layout` stride : int or a list/tuple of dim ints (where dim=2 for NCHW, dim=1 for NCH, etc.) @@ -753,10 +755,16 @@ def conv( groups : int number of groups - order : str - Ordering of dimensions. N indicates batch dimension, C indicates + data_layout : str + Layout of the input. N indicates batch dimension, C indicates channels, any other character indicates HW (or H or HWD for 1D and 3D). + kernel_layout: Optional[str] + Layout of the filter. I indicates input channels, O indicates output channels, + any other character indicates HW dimension of the filter (or H or HWD for 1D and 3D). + If kernel_layout is empty, use data_layout to infer the default kernel_layout. Default + kernel_layout is OHWI for NCHW data layout, HWIO for NHWC data layout. + out_dtype : str Elements are converted to this type before elementwise multiplication and summation. @@ -775,7 +783,7 @@ def conv( Returns ------- Output : tvm.te.Tensor - N-D with shape [batch, out_channel, out_height, out_width, ...] ordered by `order`. + N-D with shape [batch, out_channel, out_height, out_width, ...] in `data_layout` """ dim = len(inp.shape) - 2 if out_dtype is None: @@ -792,30 +800,41 @@ def conv( else: dilations = list(dilation) - # transform from order to NCHW - permutation_to = [order.find("N"), order.find("C")] + [ - x.span()[0] for x in re.finditer("[^NC]", order) + # transform from data_layout to NCHW + data_permutation_to = [data_layout.find("N"), data_layout.find("C")] + [ + x.span()[0] for x in re.finditer("[^NC]", data_layout) ] - # transform from NCHW to order - permutation_from = np.argsort(permutation_to) - # transform from CHW to order - permutation_from_reductions = permutation_from[1:].copy() - permutation_from_reductions[permutation_from_reductions > permutation_from[0]] -= 1 - - # kernel permutation, if C appears before HW then num_filter is first, otherwise it is last - # tkonolige: I don't really understand kernel ordering for NHWC, it seems - # like num_filters should match the N dimension - if order.find("C") < re.search("[^NC]", order).span()[0]: - permutation_to_kernel = [0, 1] + list(range(2, dim + 2)) + # transform from NCHW to data_layout + data_permutation_from = np.argsort(data_permutation_to) + # transform from CHW to data_layout + data_permutation_from_reductions = data_permutation_from[1:].copy() + data_permutation_from_reductions[ + data_permutation_from_reductions > data_permutation_from[0] + ] -= 1 + + if kernel_layout == "": + # kernel permutation, if C appears before HW then num_filter is first, otherwise it is last + # tkonolige: I don't really understand kernel ordering for NHWC, it seems + # like num_filters should match the N dimension + if data_layout.find("C") < re.search("[^NC]", data_layout).span()[0]: + kernel_permutation_to = [0, 1] + list(range(2, dim + 2)) + else: + kernel_permutation_to = [dim + 1, dim] + list(range(dim)) else: - permutation_to_kernel = [dim + 1, dim] + list(range(dim)) - permutation_from_kernel = np.argsort(permutation_to_kernel) + # transform from kernel_layout to OIHW + kernel_permutation_to = [kernel_layout.find("O"), kernel_layout.find("I")] + [ + x.span()[0] for x in re.finditer("[^OI]", kernel_layout) + ] + # transform from OIHW to kernel_layout + kernel_permutation_from = np.argsort(kernel_permutation_to) if meta_schedule_original_shape: 
auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape) - batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist() + batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[ + data_permutation_to + ].tolist() num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[ - permutation_to_kernel + kernel_permutation_to ].tolist() # Autoscheduler may have messed with the input layout, so we extract the @@ -841,14 +860,14 @@ def conv( ) ] # compute graph - pad_before = list(np.array([0, 0] + pad_begin)[permutation_from]) - pad_after = list(np.array([0, 0] + pad_end)[permutation_from]) + pad_before = list(np.array([0, 0] + pad_begin)[data_permutation_from]) + pad_after = list(np.array([0, 0] + pad_end)[data_permutation_from]) temp = pad(inp, pad_before, pad_after, name="pad_temp") rc = te.reduce_axis((0, in_channel // groups), name="rc") rs = [te.reduce_axis((0, k), name=f"r{i}") for i, k in zip(["y", "x", "z"], kernel_dimensions)] def compute(*args): - nn, ff, *dim_indices = list(np.array(args)[permutation_to]) + nn, ff, *dim_indices = list(np.array(args)[data_permutation_to]) if groups == 1: simplified_channel_index = rc @@ -864,25 +883,25 @@ def compute(*args): di * stride + r * dil for di, stride, r, dil in zip(dim_indices, strides, rs, dilations) ] - )[permutation_from] + )[data_permutation_from] ) ).astype(out_dtype) - * filt.__getitem__(tuple(np.array([ff, rc] + rs)[permutation_from_kernel])).astype( + * filt.__getitem__(tuple(np.array([ff, rc] + rs)[kernel_permutation_from])).astype( out_dtype ), # Schedules depend on reduction axes being in the same order as the # layout, so we reorder here. - axis=np.array([rc, *rs])[permutation_from_reductions].tolist(), + axis=np.array([rc, *rs])[data_permutation_from_reductions].tolist(), ) out = te.compute( - list(np.array([batch, out_channel] + out_dimensions)[permutation_from]), + list(np.array([batch, out_channel] + out_dimensions)[data_permutation_from]), compute, # tag is expected to be lowercase - tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}", - name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}", + tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}", + name=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}", attrs={"layout_free_placeholders": [filt]} if auto_scheduler_should_rewrite_layout else {}, - varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]), + varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[data_permutation_from]), ) # if we used autoscheduler's changed layout we need to rewrite the ordering # of the output dimensions @@ -924,7 +943,7 @@ def group_conv2d_nhwc(Input, Filter, stride, padding, dilation, groups, out_dtyp Output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, out_channel] """ - return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", out_dtype=out_dtype) + return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", "", out_dtype=out_dtype) def unpack_NCHWc_to_nchw(packed_out, out_dtype): diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py index 591c643a95c25..e3e762be47615 100644 --- a/python/tvm/topi/nn/conv3d.py +++ b/python/tvm/topi/nn/conv3d.py @@ -53,7 +53,7 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=Non Output : tvm.te.Tensor 5-D with shape [batch, out_channel, out_depth, out_height, out_width] """ - return conv(Input, Filter, 
stride, padding, dilation, groups, "NCDHW", out_dtype) + return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", "", out_dtype) def conv3d_ndhwc( @@ -111,6 +111,7 @@ def conv3d_ndhwc( dilation, groups, "NDHWC", + "", out_dtype, auto_scheduler_rewritten_layout, meta_schedule_origin_shape,
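
Notes (not part of the patch):

The heart of the change is that `conv` now derives the kernel permutation from the
`kernel_layout` string instead of guessing it from the data layout. Below is a minimal
standalone sketch of that bookkeeping, mirroring the logic in the hunk above; the helper
name `kernel_layout_permutations` is illustrative and not part of TVM:

    import re
    import numpy as np

    def kernel_layout_permutations(kernel_layout):
        # Index order that views the kernel as OIHW: positions of O, I,
        # then every remaining (spatial) axis in the order it appears.
        to_oihw = [kernel_layout.find("O"), kernel_layout.find("I")] + [
            m.span()[0] for m in re.finditer("[^OI]", kernel_layout)
        ]
        # Inverse permutation: maps OIHW-ordered indices back to kernel_layout order.
        from_oihw = np.argsort(to_oihw)
        return to_oihw, from_oihw

    to_oihw, from_oihw = kernel_layout_permutations("HWIO")
    shape_hwio = np.array([3, 3, 16, 64])            # [kh, kw, in_ch // groups, num_filter]
    print(shape_hwio[to_oihw])                       # [64 16  3  3] -> the OIHW view
    oihw_indices = np.array(["ff", "rc", "ry", "rx"])
    print(oihw_indices[from_oihw])                   # ['ry' 'rx' 'rc' 'ff'] -> index order for the HWIO tensor

Indexing the kernel shape with `to_oihw` gives the OIHW view from which `num_filter` and the
spatial extents are read, while `from_oihw` reorders the OIHW-ordered compute indices
`[ff, rc, ry, rx]` back into the layout the caller actually passed.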
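
With the patch applied, `topi.nn.conv` accepts data/kernel layout pairings that previously had
no dedicated wrapper. A sketch under assumed, purely illustrative shapes:

    from tvm import te, topi

    # NHWC activations with an OHWI-packed kernel; before this change the kernel
    # layout was always inferred from the data layout (HWIO for NHWC data).
    data = te.placeholder((1, 56, 56, 64), name="data")       # NHWC
    weight = te.placeholder((128, 3, 3, 64), name="weight")   # OHWI

    out = topi.nn.conv(
        data, weight,
        stride=1, padding=1, dilation=1, groups=1,
        data_layout="NHWC", kernel_layout="OHWI",
    )
    print(out.shape)  # expected: (1, 56, 56, 128), NHWC output

The existing wrappers (`conv1d_ncw`, `conv2d_nhwc`, `group_conv2d_nchw`, ...) keep their old
behavior by passing an empty `kernel_layout`, which falls back to the data-layout-based
inference shown in the diff.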
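
On the Relay side, `need_kernel_layout=True` makes `wrap_compute_conv2d` forward
`attrs.kernel_layout` to the compute, and the new CUDA strategy entry ("conv2d.cuda",
plevel 15) uses it when the auto-scheduler is enabled. The sketch below only constructs such
a graph; actually lowering it for CUDA additionally requires an auto-scheduler-enabled build,
tuning, and a GPU target, none of which are shown here:

    import tvm
    from tvm import relay

    # A conv2d whose kernel_layout does not match any handwritten CUDA schedule;
    # with auto-scheduling enabled, the new "conv2d.cuda" entry can now cover it.
    data = relay.var("data", shape=(1, 56, 56, 64), dtype="float32")
    weight = relay.var("weight", shape=(128, 3, 3, 64), dtype="float32")
    y = relay.nn.conv2d(
        data, weight,
        channels=128, kernel_size=(3, 3), padding=(1, 1),
        data_layout="NHWC", kernel_layout="OHWI",
    )
    print(tvm.IRModule.from_expr(relay.Function([data, weight], y)))

Which implementation is ultimately selected still depends on the target, the libraries in
`target.libs`, and `is_auto_scheduler_enabled()`; this snippet only builds the Relay graph.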