Add relay implementation to fully support conv2d_transpose in relay backend #346

Merged: 4 commits, Apr 28, 2020
6 changes: 2 additions & 4 deletions myia/compile/backends/pytorch.py
@@ -381,9 +381,7 @@ def _impl(input, weight, stride, padding, dilation, groups):
 def pytorch_conv_transpose2d(op):
     """Implementation of conv_transpose2d."""
 
-    def _impl(
-        input, weight, bias, stride, padding, output_padding, groups, dilation
-    ):
+    def _impl(input, weight, stride, padding, output_padding, groups, dilation):
         stride = tuple(_x.item() for _x in stride)
         padding = tuple(_x.item() for _x in padding)
         output_padding = tuple(_x.item() for _x in output_padding)

@@ -392,7 +390,7 @@ def _impl(
         return torch.conv_transpose2d(
             input,
             weight,
-            bias,
+            None,
             stride,
             padding,
             output_padding,
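With this change the backend primitive is bias-free: `torch.conv_transpose2d` is always called with `None` for bias, and bias handling moves to `myia.public_api` (see the `public_api.py` diff below). A minimal sketch of the positional call the backend now emits, with illustrative shapes:

```python
import torch

x = torch.randn(1, 3, 4, 4)
w = torch.randn(3, 2, 3, 3)  # (in_channels, out_channels, kh, kw)
# Same positional argument order as the backend call, with bias=None:
y = torch.conv_transpose2d(x, w, None, (1, 1), (0, 0), (0, 0), 1, (1, 1))
print(y.shape)  # torch.Size([1, 2, 6, 6])
```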
135 changes: 105 additions & 30 deletions myia/compile/backends/relay.py
@@ -314,6 +314,9 @@ def relay_conv2d(c, img, w, stride, pad, dil, groups):


 def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):
+    # This implementation should match the one in the pytorch backend
+    # (myia.compile.backends.pytorch_conv_grad.conv2d_weight).
+
     assert wsize.is_constant(tuple)
     assert stride.is_constant(tuple)
     assert pad.is_constant(tuple)
@@ -322,7 +325,7 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):

     batch, in_channel, in_h, in_w = data.abstract.xshape()
     out_channel, _, filter_h, filter_w = wsize.value
-    _, _, grad_h, grad_w = dout.abstract.xshape()
+    grad_sh0, grad_sh1, grad_h, grad_w = dout.abstract.xshape()
     pad_h, pad_w = pad.value
 
     data = c.ref(data)
@@ -354,18 +357,36 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):
         dilation=stride.value,
         groups=batch * in_channel,
     )
+
+    conv_sh1 = grad_sh0 * grad_sh1 * (in_channel // groups.value)
     d = relay.reshape(
         d,
-        [
-            batch,
-            in_channel // groups.value,
-            out_channel,
-            padded_weight_grad_h,
-            padded_weight_grad_w,
-        ],
+        [batch, conv_sh1 // batch, padded_weight_grad_h, padded_weight_grad_w],
     )
     d = relay.sum(d, axis=0)
-    d = relay.transpose(d, [1, 0, 2, 3])
+
+    if groups.value > 1:
+        d = relay.reshape(
+            d,
+            [
+                grad_sh1,
+                in_channel // groups.value,
+                padded_weight_grad_h,
+                padded_weight_grad_w,
+            ],
+        )
+    else:
+        d = relay.reshape(
+            d,
+            [
+                in_channel // groups.value,
+                grad_sh1,
+                padded_weight_grad_h,
+                padded_weight_grad_w,
+            ],
+        )
+    d = relay.transpose(d, [1, 0, 2, 3])
+
     if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
         d = relay.strided_slice(
             d, begin=[0, 0, 0, 0], end=[None, None, filter_h, filter_w]
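The added comment pins down the contract: this relay code must produce the same weight gradient as the pytorch backend's `conv2d_weight`. A quick sanity check of that contract, sketched with plain PyTorch against `torch.nn.grad.conv2d_weight` (the reference that `myia.compile.backends.pytorch_conv_grad.conv2d_weight` mirrors); the shapes are illustrative and include a grouped case:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 9, 9, requires_grad=True)
w = torch.randn(6, 2, 3, 3, requires_grad=True)  # groups=2: 4 in, 6 out
y = F.conv2d(x, w, stride=2, padding=1, dilation=1, groups=2)
dout = torch.randn_like(y)

# Weight gradient via autograd...
gw_auto = torch.autograd.grad(y, w, dout)[0]
# ...and via the explicit helper the relay lowering is expected to match.
gw_ref = torch.nn.grad.conv2d_weight(
    x, w.shape, dout, stride=2, padding=1, dilation=1, groups=2
)
assert torch.allclose(gw_auto, gw_ref, atol=1e-5)
```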
@@ -374,13 +395,17 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):


def relay_conv_transpose2d(
c, input, weight, bias, stride, padding, output_padding, groups, dilation
c, input, weight, stride, padding, output_padding, groups, dilation
):

if not bias.is_constant(type(None)):
raise NotImplementedError(
"conv_transpose2d: bias not yet supported " "in relay backend."
)
"""Implement conv2d_transpose using 10 relay calls including conv2d.

Support all values for groups, dilation, strides, padding and
output padding.
Based on Theano implementation (2020/04/14):
https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/abstract_conv.py#L2927
Need implementation of operation relay.nn.dilate
in TVM relay backend
"""

assert stride.is_constant(tuple)
assert padding.is_constant(tuple)
@@ -389,23 +414,73 @@ def relay_conv_transpose2d(
     assert groups.is_constant(int)
 
     data_shape = input.abstract.xshape()
-    weight_shape = weight.abstract.xshape()
-    _, in_channels, _, _ = data_shape
-    _, w_c, filter_h, filter_w = weight_shape
-    _i = c.ref(input)
-    _w = c.ref(weight)
-    return relay.nn.conv2d_transpose(
-        _i,
-        _w,
-        strides=stride.value,
-        padding=padding.value,
-        dilation=dilation.value,
-        groups=groups.value,
-        output_padding=output_padding.value,
-        kernel_size=(filter_h, filter_w),
-        channels=in_channels,
-    )
+    kern_shape = weight.abstract.xshape()
+    h_in, w_in = data_shape[2:]
+    filter_h, filter_w = kern_shape[2:]
+    strides = stride.value
+    padding = padding.value
+    dilation = dilation.value
+    output_padding = output_padding.value
+    groups = groups.value
+    data = c.ref(input)
+    weight = c.ref(weight)
+
+    h_out = (
+        (h_in - 1) * strides[0]
+        - 2 * padding[0]
+        + dilation[0] * (filter_h - 1)
+        + output_padding[0]
+        + 1
+    )
+    w_out = (
+        (w_in - 1) * strides[1]
+        - 2 * padding[1]
+        + dilation[1] * (filter_w - 1)
+        + output_padding[1]
+        + 1
+    )
+
+    data_dilated = relay.nn.dilate(data, (1, 1) + strides)
+    data_padded = relay.nn.pad(
+        data_dilated,
+        ((0, 0), (0, 0), (0, output_padding[0]), (0, output_padding[1])),
+    )
+
+    # Pre-process the kernel:
+    # from (m0, m1, m2, m3) to (m1 * g, m0 // g, m2, m3).
+    mshp0 = kern_shape[0] // groups
+    c_out = kern_shape[1] * groups
+    kern = relay.reshape(weight, (groups, mshp0) + kern_shape[1:])
+    # => (g, m0 // g, m1, m2, m3)
+    kern = relay.op.transpose(kern, axes=(1, 0, 2, 3, 4))
+    # => (m0 // g, g, m1, m2, m3)
+    kern = relay.reshape(kern, (mshp0, c_out, kern_shape[-2], kern_shape[-1]))
+    # => (m0 // g, m1 * g, m2, m3)
+    kern = relay.op.transpose(kern, (1, 0, 2, 3))
+    # => (m1 * g, m0 // g, m2, m3)
+    # The kernel's two last dimensions must be flipped.
+    kern = relay.op.transform.reverse(kern, 2)
+    kern = relay.op.transform.reverse(kern, 3)
+    # End of kernel pre-processing.
+
+    img = relay.nn.conv2d(
+        data_padded,
+        kern,
+        groups=groups,
+        channels=c_out,
+        padding=[(kern_shape[2 + i] - 1) * dilation[i] for i in range(2)],
+        dilation=dilation,
+    )
+
+    if any(p != 0 for p in padding):
+        img = relay.op.transform.strided_slice(
+            data=img,
+            begin=[0, 0, padding[0], padding[1]],
+            end=[None, None, h_out + padding[0], w_out + padding[1]],
+        )
+
+    return img
 
 
 def relay_concat(c, x, dim):
     assert dim.is_constant(int)
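The new lowering follows the Theano recipe cited in the docstring: dilate the input by the strides, pad bottom/right with the output padding, rearrange the kernel from (m0, m1, m2, m3) to (m1 * g, m0 // g, m2, m3) with flipped spatial dimensions, run a "full" conv2d, then crop away the forward padding. A self-contained sketch of the same recipe in PyTorch terms, checked against `torch.nn.functional.conv_transpose2d` (the helper name and shapes are illustrative, not part of the PR; assumes `output_padding < stride`):

```python
import torch
import torch.nn.functional as F

def conv_transpose2d_via_conv2d(x, w, stride, padding, output_padding,
                                groups, dilation):
    """Illustrative re-implementation of the lowering above."""
    n, c_in, h_in, w_in = x.shape
    m0, m1, kh, kw = w.shape  # (in_channels, out_channels // groups, kh, kw)
    h_out = ((h_in - 1) * stride[0] - 2 * padding[0]
             + dilation[0] * (kh - 1) + output_padding[0] + 1)
    w_out = ((w_in - 1) * stride[1] - 2 * padding[1]
             + dilation[1] * (kw - 1) + output_padding[1] + 1)

    # Dilate the input (stride-1 zeros between pixels) and pad bottom/right
    # with the output padding -- the relay.nn.dilate + relay.nn.pad steps.
    xd = x.new_zeros(n, c_in,
                     (h_in - 1) * stride[0] + 1 + output_padding[0],
                     (w_in - 1) * stride[1] + 1 + output_padding[1])
    xd[:, :, ::stride[0], ::stride[1]] = x

    # Kernel: (m0, m1, kh, kw) -> (m1 * g, m0 // g, kh, kw), flipped spatially.
    k = w.reshape(groups, m0 // groups, m1, kh, kw)
    k = k.permute(1, 0, 2, 3, 4).reshape(m0 // groups, m1 * groups, kh, kw)
    k = k.permute(1, 0, 2, 3).flip(-2, -1)

    # "Full" cross-correlation, then crop the forward padding.
    y = F.conv2d(xd, k, dilation=dilation, groups=groups,
                 padding=((kh - 1) * dilation[0], (kw - 1) * dilation[1]))
    return y[:, :, padding[0]:h_out + padding[0], padding[1]:w_out + padding[1]]

x = torch.randn(2, 4, 5, 6)
w = torch.randn(4, 3, 3, 2)  # groups=2 -> 6 output channels
args = dict(stride=(2, 1), padding=(1, 2), output_padding=(1, 0),
            groups=2, dilation=(2, 3))
ref = F.conv_transpose2d(x, w, None, **args)
assert torch.allclose(ref, conv_transpose2d_via_conv2d(x, w, **args), atol=1e-5)
```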
1 change: 0 additions & 1 deletion myia/operations/macro_conv2d_grad_input.py
@@ -67,7 +67,6 @@ async def conv2d_grad_input(
         P.conv_transpose2d,
         r_grad_output.node,
         r_weight.node,
-        Constant(None),
         r_stride.node,
         r_padding.node,
         Constant(grad_input_padding),
1 change: 0 additions & 1 deletion myia/operations/prim_conv_transpose2d.py
@@ -18,7 +18,6 @@ async def infer_conv_transpose2d(
     engine,
     input: AbstractArray,
     weight: AbstractArray,
-    bias,  # un-typed, because it may be either None or an abstract array.
     stride: AbstractTuple,
     padding: AbstractTuple,
     output_padding: AbstractTuple,
7 changes: 5 additions & 2 deletions myia/public_api.py
@@ -312,9 +312,12 @@ def conv_transpose2d(
     dilation=1,
 ):
     """Map of Pytorch method torch.nn.functional.conv_transpose2d."""
-    return P.conv_transpose2d(
-        input, weight, bias, stride, padding, output_padding, groups, dilation
+    ret = P.conv_transpose2d(
+        input, weight, stride, padding, output_padding, groups, dilation
     )
+    if bias is not None:
+        ret = ret + reshape(bias, (1, bias.shape[0], 1, 1))
+    return ret
 
 
 @core
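Bias is now applied entirely at this level: adding a per-channel bias after the bias-free primitive is equivalent to passing it to the op, since it broadcasts as (1, C, 1, 1). A quick check of that identity in plain PyTorch:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(3, 5, 4, 4)  # (in_channels, out_channels, kh, kw)
b = torch.randn(5)

ref = F.conv_transpose2d(x, w, b, stride=2)
# Bias-free call plus broadcast add, as public_api.py now does:
out = F.conv_transpose2d(x, w, None, stride=2) + b.reshape(1, b.shape[0], 1, 1)
assert torch.allclose(ref, out)
```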
2 changes: 1 addition & 1 deletion requirements-cpu.conda
@@ -2,7 +2,7 @@ python>=3.7,<3.8a0
 colorama
 numpy
 opt_einsum
-abergeron::tvm==0.7dev1+0.*
+abergeron::tvm==0.7dev1+1.*
 pytest
 pytest-cov
 yaml
2 changes: 1 addition & 1 deletion requirements-gpu.conda
@@ -4,7 +4,7 @@ numpy
 opt_einsum
 cudatoolkit=10.0.*
 cudnn
-abergeron::tvm==0.7dev1+0.*
+abergeron::tvm==0.7dev1+1.*
 pytest
 pytest-cov
 yaml