Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

(wip)(do not merge) Run all tests on relay #338

Closed
wants to merge 11 commits into from
6 changes: 2 additions & 4 deletions myia/compile/backends/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,7 @@ def _impl(input, weight, stride, padding, dilation, groups):
def pytorch_conv_transpose2d(op):
"""Implementation of conv_transpose2d."""

def _impl(
input, weight, bias, stride, padding, output_padding, groups, dilation
):
def _impl(input, weight, stride, padding, output_padding, groups, dilation):
stride = tuple(_x.item() for _x in stride)
padding = tuple(_x.item() for _x in padding)
output_padding = tuple(_x.item() for _x in output_padding)
Expand All @@ -392,7 +390,7 @@ def _impl(
return torch.conv_transpose2d(
input,
weight,
bias,
None,
stride,
padding,
output_padding,
Expand Down
19 changes: 11 additions & 8 deletions myia/compile/backends/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,26 +374,29 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups):


def relay_conv_transpose2d(
c, input, weight, bias, stride, padding, output_padding, groups, dilation
c, input, weight, stride, padding, output_padding, groups, dilation
):

if not bias.is_constant(type(None)):
raise NotImplementedError(
"conv_transpose2d: bias not yet supported " "in relay backend."
)

assert stride.is_constant(tuple)
assert padding.is_constant(tuple)
assert output_padding.is_constant(tuple)
assert dilation.is_constant(tuple)
assert groups.is_constant(int)

# Only groups==1 and dilation==(1, 1) is supported, on both CPU and GPU.
if groups.value != 1:
raise RuntimeError(f"Only support groups=1, got {groups.value}")
if dilation.value != (1, 1):
raise RuntimeError(
f"Only support dilation=(1, 1), got {dilation.value}"
)

data_shape = input.abstract.xshape()
weight_shape = weight.abstract.xshape()
_, in_channels, _, _ = data_shape
_, w_c, filter_h, filter_w = weight_shape
_i = c.ref(input)
_w = c.ref(weight)
output_channels = w_c * groups.value
return relay.nn.conv2d_transpose(
_i,
_w,
Expand All @@ -403,7 +406,7 @@ def relay_conv_transpose2d(
groups=groups.value,
output_padding=output_padding.value,
kernel_size=(filter_h, filter_w),
channels=in_channels,
channels=output_channels,
)


Expand Down
1 change: 0 additions & 1 deletion myia/operations/macro_conv2d_grad_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def conv2d_grad_input(
P.conv_transpose2d,
r_grad_output.node,
r_weight.node,
Constant(None),
r_stride.node,
r_padding.node,
Constant(grad_input_padding),
Expand Down
1 change: 0 additions & 1 deletion myia/operations/prim_conv_transpose2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ async def infer_conv_transpose2d(
engine,
input: AbstractArray,
weight: AbstractArray,
bias, # un-typed, because it may be either None or an abstract array.
stride: AbstractTuple,
padding: AbstractTuple,
output_padding: AbstractTuple,
Expand Down
2 changes: 1 addition & 1 deletion myia/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def conv_transpose2d(
):
"""Map of Pytorch method torch.nn.functional.conv_transpose2d."""
return P.conv_transpose2d(
input, weight, bias, stride, padding, output_padding, groups, dilation
input, weight, stride, padding, output_padding, groups, dilation
)


Expand Down
Loading