[TOPI] Allow conv definition to have custom kernel layout
vinx13 committed Jun 28, 2022
1 parent 6c433d2 commit b444f70
Showing 5 changed files with 79 additions and 46 deletions.
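The commit extends the generic `topi.nn.conv` compute with an explicit kernel layout instead of always inferring it from the data layout. As a rough usage sketch (editorial, not part of the patch; placeholder shapes are illustrative and the keyword names follow the docstrings in the diff below), a direct call with NHWC data and an HWIO filter might look like this:

```python
from tvm import te, topi

# NHWC input and HWIO filter, declared as TE placeholders.
data = te.placeholder((1, 56, 56, 64), name="data", dtype="float32")
weight = te.placeholder((3, 3, 64, 128), name="weight", dtype="float32")

out = topi.nn.conv(
    data,               # input tensor
    weight,             # filter tensor
    (1, 1),             # stride
    (1, 1),             # padding
    (1, 1),             # dilation
    1,                  # groups
    data_layout="NHWC",
    kernel_layout="HWIO",   # new in this commit; "" keeps the old inferred behaviour
    out_dtype="float32",
)
print(out.shape)  # expected: (1, 56, 56, 128)
```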
9 changes: 9 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -322,6 +322,15 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
elif is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(
topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True
),
naive_schedule,
name="conv2d.cuda",
plevel=15,
)
elif target.kind.name == "cuda" and "cudnn" not in target.libs:
# No TVM native kernel applicable
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
4 changes: 4 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -224,6 +224,7 @@ def schedule_bitpack(attrs, outs, target):
def wrap_compute_conv2d(
topi_compute,
need_data_layout=False,
need_kernel_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
@@ -236,6 +237,7 @@ def _compute_conv2d(attrs, inputs, out_type):
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
data_layout = attrs.get_str("data_layout")
kernel_layout = attrs.get_str("kernel_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
Expand All @@ -244,6 +246,8 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(attrs.groups)
if need_data_layout:
args.append(data_layout)
if need_kernel_layout:
args.append(kernel_layout)
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
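For orientation, the snippet below is an editorial sketch (not part of the patch) of how `_compute_conv2d` orders the positional arguments it hands to the TOPI compute. Only the flag-guarded appends are taken from the diff above; the first five entries are an assumption based on the surrounding code.

```python
# Sketch of the argument-ordering logic in _compute_conv2d. The append order
# for groups / data_layout / kernel_layout / out_layout / out_dtype mirrors the
# diff; the leading [data, kernel, strides, padding, dilation] is assumed.
def assemble_conv2d_args(
    data, kernel, strides, padding, dilation, groups,
    data_layout, kernel_layout, out_layout, out_dtype,
    has_groups=False, need_data_layout=False,
    need_kernel_layout=False, need_out_layout=False,
):
    args = [data, kernel, strides, padding, dilation]
    if has_groups:
        args.append(groups)
    if need_data_layout:
        args.append(data_layout)
    if need_kernel_layout:  # new in this commit
        args.append(kernel_layout)
    if need_out_layout:
        args.append(out_layout)
    args.append(out_dtype)
    return args


# With the flags used by the CUDA strategy above:
print(assemble_conv2d_args(
    "data", "kernel", (1, 1), (1, 1), (1, 1), 1, "NHWC", "HWIO", "", "float32",
    has_groups=True, need_data_layout=True, need_kernel_layout=True,
))
# ['data', 'kernel', (1, 1), (1, 1), (1, 1), 1, 'NHWC', 'HWIO', 'float32']
```

So with `has_groups`, `need_data_layout`, and `need_kernel_layout` all set, the wrapped compute ends up calling `topi.nn.conv(data, kernel, strides, padding, dilation, groups, data_layout, kernel_layout, out_dtype)`.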
10 changes: 5 additions & 5 deletions python/tvm/topi/nn/conv1d.py
@@ -47,17 +47,17 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", o
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, 1, layout, out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, layout, "", out_dtype)


def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NWC layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NWC", "", out_dtype=out_dtype)


def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NCW layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NCW", "", out_dtype=out_dtype)


def group_conv1d_nwc(
@@ -89,7 +89,7 @@ def group_conv1d_nwc(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NWC", "", out_dtype=out_dtype)


def group_conv1d_ncw(
@@ -121,4 +121,4 @@ def group_conv1d_ncw(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NCW", "", out_dtype=out_dtype)
99 changes: 59 additions & 40 deletions python/tvm/topi/nn/conv2d.py
@@ -89,7 +89,7 @@ def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=N
"""
# search platform specific declaration first
# default declaration
return conv(input, filter, strides, padding, dilation, 1, layout, out_dtype)
return conv(input, filter, strides, padding, dilation, 1, layout, "", out_dtype)


@tvm.target.generic_func
@@ -239,7 +239,7 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", out_dtype=out_dtype)
return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", "", out_dtype=out_dtype)


def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -269,7 +269,7 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
output : tvm.te.Tensor
4-D with shape [out_height, out_width, out_channel, batch]
"""
return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", out_dtype=out_dtype)
return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", "", out_dtype=out_dtype)


def conv2d_nhwc(
@@ -325,6 +325,7 @@ def conv2d_nhwc(
dilation,
1,
"NHWC",
"",
out_dtype,
auto_scheduler_rewritten_layout,
meta_schedule_original_shape,
@@ -708,7 +709,7 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", out_dtype=out_dtype)
return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", "", out_dtype=out_dtype)


def conv(
@@ -718,7 +719,8 @@ def conv(
padding: Union[int, Sequence[int]],
dilation: Union[int, Sequence[int]],
groups: int,
order: str,
data_layout: str,
kernel_layout: str = "",
out_dtype: Union[str, None] = None,
auto_scheduler_rewritten_layout: Optional[str] = None,
meta_schedule_original_shape=None,
@@ -731,11 +733,11 @@
Parameters
----------
inp : tvm.te.Tensor
N-D with shape [batch, in_channel, in_height, in_width, ...] ordered by `order`
N-D with shape [batch, in_channel, in_height, in_width, ...] in `data_layout`
filt : tvm.te.Tensor
N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...]
for NCHW or [filter_height, filter_width, ..., in_channel // groups, num_filter] for NHWC
N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...] in
`kernel_layout`
stride : int or a list/tuple of dim ints
(where dim=2 for NCHW, dim=1 for NCH, etc.)
@@ -753,10 +755,16 @@
groups : int
number of groups
order : str
Ordering of dimensions. N indicates batch dimension, C indicates
data_layout : str
Layout of the input. N indicates batch dimension, C indicates
channels, any other character indicates HW (or H or HWD for 1D and 3D).
kernel_layout: Optional[str]
Layout of the filter. I indicates input channels, O indicates output channels,
any other character indicates HW dimension of the filter (or H or HWD for 1D and 3D).
If kernel_layout is empty, the default kernel_layout is inferred from data_layout:
OIHW for NCHW data layout and HWIO for NHWC data layout.
out_dtype : str
Elements are converted to this type before elementwise multiplication
and summation.
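As a worked illustration of the default-inference rule documented above (an editorial sketch; `default_kernel_layout` is a hypothetical helper, and the real code builds index permutations rather than layout strings):

```python
# Sketch of the default-kernel-layout rule: if channels come before the
# spatial dims (NCHW-style), the filter is O, I, spatial...; otherwise
# (NHWC-style) it is spatial..., I, O.
def default_kernel_layout(data_layout: str) -> str:
    spatial = [c for c in data_layout if c not in "NC"]
    if data_layout.index("C") < data_layout.index(spatial[0]):
        return "OI" + "".join(spatial)
    return "".join(spatial) + "IO"


print(default_kernel_layout("NCHW"))   # OIHW
print(default_kernel_layout("NHWC"))   # HWIO
print(default_kernel_layout("NCW"))    # OIW   (1D)
print(default_kernel_layout("NDHWC"))  # DHWIO (3D)
```

In the permutation code further down, the same rule shows up as the identity kernel permutation (OIHW-style) versus `[dim + 1, dim] + list(range(dim))` (HWIO-style).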
@@ -775,7 +783,7 @@
Returns
-------
Output : tvm.te.Tensor
N-D with shape [batch, out_channel, out_height, out_width, ...] ordered by `order`.
N-D with shape [batch, out_channel, out_height, out_width, ...] in `data_layout`
"""
dim = len(inp.shape) - 2
if out_dtype is None:
@@ -792,30 +800,41 @@
else:
dilations = list(dilation)

# transform from order to NCHW
permutation_to = [order.find("N"), order.find("C")] + [
x.span()[0] for x in re.finditer("[^NC]", order)
# transform from data_layout to NCHW
data_permutation_to = [data_layout.find("N"), data_layout.find("C")] + [
x.span()[0] for x in re.finditer("[^NC]", data_layout)
]
# transform from NCHW to order
permutation_from = np.argsort(permutation_to)
# transform from CHW to order
permutation_from_reductions = permutation_from[1:].copy()
permutation_from_reductions[permutation_from_reductions > permutation_from[0]] -= 1

# kernel permutation, if C appears before HW then num_filter is first, otherwise it is last
# tkonolige: I don't really understand kernel ordering for NHWC, it seems
# like num_filters should match the N dimension
if order.find("C") < re.search("[^NC]", order).span()[0]:
permutation_to_kernel = [0, 1] + list(range(2, dim + 2))
# transform from NCHW to data_layout
data_permutation_from = np.argsort(data_permutation_to)
# transform from CHW to data_layout
data_permutation_from_reductions = data_permutation_from[1:].copy()
data_permutation_from_reductions[
data_permutation_from_reductions > data_permutation_from[0]
] -= 1

if kernel_layout == "":
# kernel permutation, if C appears before HW then num_filter is first, otherwise it is last
# tkonolige: I don't really understand kernel ordering for NHWC, it seems
# like num_filters should match the N dimension
if data_layout.find("C") < re.search("[^NC]", data_layout).span()[0]:
kernel_permutation_to = [0, 1] + list(range(2, dim + 2))
else:
kernel_permutation_to = [dim + 1, dim] + list(range(dim))
else:
permutation_to_kernel = [dim + 1, dim] + list(range(dim))
permutation_from_kernel = np.argsort(permutation_to_kernel)
# transform from kernel_layout to OIHW
kernel_permutation_to = [kernel_layout.find("O"), kernel_layout.find("I")] + [
x.span()[0] for x in re.finditer("[^OI]", kernel_layout)
]
# transform from OIHW to kernel_layout
kernel_permutation_from = np.argsort(kernel_permutation_to)

if meta_schedule_original_shape:
auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape)
batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist()
batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[
data_permutation_to
].tolist()
num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
permutation_to_kernel
kernel_permutation_to
].tolist()

# Autoscheduler may have messed with the input layout, so we extract the
@@ -841,14 +860,14 @@ def conv(
)
]
# compute graph
pad_before = list(np.array([0, 0] + pad_begin)[permutation_from])
pad_after = list(np.array([0, 0] + pad_end)[permutation_from])
pad_before = list(np.array([0, 0] + pad_begin)[data_permutation_from])
pad_after = list(np.array([0, 0] + pad_end)[data_permutation_from])
temp = pad(inp, pad_before, pad_after, name="pad_temp")
rc = te.reduce_axis((0, in_channel // groups), name="rc")
rs = [te.reduce_axis((0, k), name=f"r{i}") for i, k in zip(["y", "x", "z"], kernel_dimensions)]

def compute(*args):
nn, ff, *dim_indices = list(np.array(args)[permutation_to])
nn, ff, *dim_indices = list(np.array(args)[data_permutation_to])

if groups == 1:
simplified_channel_index = rc
@@ -864,25 +883,25 @@ def compute(*args):
di * stride + r * dil
for di, stride, r, dil in zip(dim_indices, strides, rs, dilations)
]
)[permutation_from]
)[data_permutation_from]
)
).astype(out_dtype)
* filt.__getitem__(tuple(np.array([ff, rc] + rs)[permutation_from_kernel])).astype(
* filt.__getitem__(tuple(np.array([ff, rc] + rs)[kernel_permutation_from])).astype(
out_dtype
),
# Schedules depend on reduction axes being in the same order as the
# layout, so we reorder here.
axis=np.array([rc, *rs])[permutation_from_reductions].tolist(),
axis=np.array([rc, *rs])[data_permutation_from_reductions].tolist(),
)

out = te.compute(
list(np.array([batch, out_channel] + out_dimensions)[permutation_from]),
list(np.array([batch, out_channel] + out_dimensions)[data_permutation_from]),
compute,
# tag is expected to be lowercase
tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}",
name=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}",
attrs={"layout_free_placeholders": [filt]} if auto_scheduler_should_rewrite_layout else {},
varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]),
varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[data_permutation_from]),
)
# if we used autoscheduler's changed layout we need to rewrite the ordering
# of the output dimensions
Expand Down Expand Up @@ -924,7 +943,7 @@ def group_conv2d_nhwc(Input, Filter, stride, padding, dilation, groups, out_dtyp
Output : tvm.te.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", out_dtype=out_dtype)
return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", "", out_dtype=out_dtype)


def unpack_NCHWc_to_nchw(packed_out, out_dtype):
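To make the permutation bookkeeping in `conv` above concrete, here is a small standalone sketch (editorial, not part of the patch) that evaluates the same expressions for an NHWC data layout and an HWIO kernel layout; the variable names mirror the patch:

```python
import re
import numpy as np

data_layout, kernel_layout = "NHWC", "HWIO"

# transform from data_layout to NCHW (same expression as in the patch)
data_permutation_to = [data_layout.find("N"), data_layout.find("C")] + [
    x.span()[0] for x in re.finditer("[^NC]", data_layout)
]
# transform from kernel_layout to OIHW (same expression as in the patch)
kernel_permutation_to = [kernel_layout.find("O"), kernel_layout.find("I")] + [
    x.span()[0] for x in re.finditer("[^OI]", kernel_layout)
]

print(data_permutation_to)                # [0, 3, 1, 2]
print(kernel_permutation_to)              # [3, 2, 0, 1]
print(np.argsort(data_permutation_to))    # [0 2 3 1]  -> maps NCHW indices back to NHWC
print(np.argsort(kernel_permutation_to))  # [2 3 1 0]  -> maps OIHW indices back to HWIO
```

Indexing an NHWC shape with `data_permutation_to` yields the NCHW-ordered shape, and `np.argsort` gives the inverse permutation that the compute uses to map padding, output shapes, and iteration variables back into the original layouts.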
3 changes: 2 additions & 1 deletion python/tvm/topi/nn/conv3d.py
@@ -53,7 +53,7 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=Non
Output : tvm.te.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", out_dtype)
return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", "", out_dtype)


def conv3d_ndhwc(
@@ -111,6 +111,7 @@ def conv3d_ndhwc(
dilation,
groups,
"NDHWC",
"",
out_dtype,
auto_scheduler_rewritten_layout,
meta_schedule_origin_shape,
