diff --git a/myia/compile/backends/pytorch.py b/myia/compile/backends/pytorch.py index 5bfbafa3d..c3f2fbfec 100644 --- a/myia/compile/backends/pytorch.py +++ b/myia/compile/backends/pytorch.py @@ -381,9 +381,7 @@ def _impl(input, weight, stride, padding, dilation, groups): def pytorch_conv_transpose2d(op): """Implementation of conv_transpose2d.""" - def _impl( - input, weight, bias, stride, padding, output_padding, groups, dilation - ): + def _impl(input, weight, stride, padding, output_padding, groups, dilation): stride = tuple(_x.item() for _x in stride) padding = tuple(_x.item() for _x in padding) output_padding = tuple(_x.item() for _x in output_padding) @@ -392,7 +390,7 @@ def _impl( return torch.conv_transpose2d( input, weight, - bias, + None, stride, padding, output_padding, diff --git a/myia/compile/backends/relay.py b/myia/compile/backends/relay.py index 3b9da4c21..784b21b7e 100644 --- a/myia/compile/backends/relay.py +++ b/myia/compile/backends/relay.py @@ -314,6 +314,9 @@ def relay_conv2d(c, img, w, stride, pad, dil, groups): def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups): + # This implementation should match the one in pytorch backend + # (myia.compile.backends.pytorch_conv_grad.conv2d_weight) + assert wsize.is_constant(tuple) assert stride.is_constant(tuple) assert pad.is_constant(tuple) @@ -322,7 +325,7 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups): batch, in_channel, in_h, in_w = data.abstract.xshape() out_channel, _, filter_h, filter_w = wsize.value - _, _, grad_h, grad_w = dout.abstract.xshape() + grad_sh0, grad_sh1, grad_h, grad_w = dout.abstract.xshape() pad_h, pad_w = pad.value data = c.ref(data) @@ -354,18 +357,36 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups): dilation=stride.value, groups=batch * in_channel, ) + + conv_sh1 = grad_sh0 * grad_sh1 * (in_channel // groups.value) d = relay.reshape( d, - [ - batch, - in_channel // groups.value, - out_channel, - padded_weight_grad_h, - padded_weight_grad_w, - ], + [batch, conv_sh1 // batch, padded_weight_grad_h, padded_weight_grad_w], ) d = relay.sum(d, axis=0) - d = relay.transpose(d, [1, 0, 2, 3]) + + if groups.value > 1: + d = relay.reshape( + d, + [ + grad_sh1, + in_channel // groups.value, + padded_weight_grad_h, + padded_weight_grad_w, + ], + ) + else: + d = relay.reshape( + d, + [ + in_channel // groups.value, + grad_sh1, + padded_weight_grad_h, + padded_weight_grad_w, + ], + ) + d = relay.transpose(d, [1, 0, 2, 3]) + if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: d = relay.strided_slice( d, begin=[0, 0, 0, 0], end=[None, None, filter_h, filter_w] @@ -374,13 +395,17 @@ def relay_conv2d_weight_grad(c, data, wsize, dout, stride, pad, dil, groups): def relay_conv_transpose2d( - c, input, weight, bias, stride, padding, output_padding, groups, dilation + c, input, weight, stride, padding, output_padding, groups, dilation ): - - if not bias.is_constant(type(None)): - raise NotImplementedError( - "conv_transpose2d: bias not yet supported " "in relay backend." - ) + """Implement conv2d_transpose using 10 relay calls including conv2d. + + Support all values for groups, dilation, strides, padding and + output padding. + Based on Theano implementation (2020/04/14): + https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/abstract_conv.py#L2927 + Need implementation of operation relay.nn.dilate + in TVM relay backend + """ assert stride.is_constant(tuple) assert padding.is_constant(tuple) @@ -389,23 +414,73 @@ def relay_conv_transpose2d( assert groups.is_constant(int) data_shape = input.abstract.xshape() - weight_shape = weight.abstract.xshape() - _, in_channels, _, _ = data_shape - _, w_c, filter_h, filter_w = weight_shape - _i = c.ref(input) - _w = c.ref(weight) - return relay.nn.conv2d_transpose( - _i, - _w, - strides=stride.value, - padding=padding.value, - dilation=dilation.value, - groups=groups.value, - output_padding=output_padding.value, - kernel_size=(filter_h, filter_w), - channels=in_channels, + kern_shape = weight.abstract.xshape() + h_in, w_in = data_shape[2:] + filter_h, filter_w = kern_shape[2:] + strides = stride.value + padding = padding.value + dilation = dilation.value + output_padding = output_padding.value + groups = groups.value + data = c.ref(input) + weight = c.ref(weight) + + h_out = ( + (h_in - 1) * strides[0] + - 2 * padding[0] + + dilation[0] * (filter_h - 1) + + output_padding[0] + + 1 + ) + w_out = ( + (w_in - 1) * strides[1] + - 2 * padding[1] + + dilation[1] * (filter_w - 1) + + output_padding[1] + + 1 + ) + + data_dilated = relay.nn.dilate(data, (1, 1) + strides) + data_padded = relay.nn.pad( + data_dilated, + ((0, 0), (0, 0), (0, output_padding[0]), (0, output_padding[1]),), ) + # Pre-process kernel, + # from (m0, m1, m2, m3) to (m1 * g, m0 // g, m2, m3). + mshp0 = kern_shape[0] // groups + c_out = kern_shape[1] * groups + kern = relay.reshape(weight, (groups, mshp0) + kern_shape[1:]) + # => (g, m0 // g, m1, m2, m3) + kern = relay.op.transpose(kern, axes=(1, 0, 2, 3, 4)) + # => (m0 // g, g, m1, m2, m3) + kern = relay.reshape(kern, (mshp0, c_out, kern_shape[-2], kern_shape[-1])) + # => (m0 // g, m1 * g, m2, m3) + kern = relay.op.transpose(kern, (1, 0, 2, 3)) + # => (m1 * g, m0 // g, m2, m3) + # Kernel 2 latest dimensions must be flipped + kern = relay.op.transform.reverse(kern, 2) + kern = relay.op.transform.reverse(kern, 3) + # End pre-processing kernel. + + img = relay.nn.conv2d( + data_padded, + kern, + groups=groups, + channels=c_out, + padding=[(kern_shape[2 + i] - 1) * dilation[i] for i in range(2)], + dilation=dilation, + ) + + if any(p != 0 for p in padding): + img = relay.op.transform.strided_slice( + data=img, + begin=[0, 0, padding[0], padding[1]], + end=[None, None, h_out + padding[0], w_out + padding[1]], + ) + + return img + def relay_concat(c, x, dim): assert dim.is_constant(int) diff --git a/myia/operations/macro_conv2d_grad_input.py b/myia/operations/macro_conv2d_grad_input.py index e62de41b2..0b6f6acf6 100644 --- a/myia/operations/macro_conv2d_grad_input.py +++ b/myia/operations/macro_conv2d_grad_input.py @@ -67,7 +67,6 @@ async def conv2d_grad_input( P.conv_transpose2d, r_grad_output.node, r_weight.node, - Constant(None), r_stride.node, r_padding.node, Constant(grad_input_padding), diff --git a/myia/operations/prim_conv_transpose2d.py b/myia/operations/prim_conv_transpose2d.py index d8a8338be..ba7569959 100644 --- a/myia/operations/prim_conv_transpose2d.py +++ b/myia/operations/prim_conv_transpose2d.py @@ -18,7 +18,6 @@ async def infer_conv_transpose2d( engine, input: AbstractArray, weight: AbstractArray, - bias, # un-typed, because it may be either None or an abstract array. stride: AbstractTuple, padding: AbstractTuple, output_padding: AbstractTuple, diff --git a/myia/public_api.py b/myia/public_api.py index 74d7a2874..0c3f6fe55 100644 --- a/myia/public_api.py +++ b/myia/public_api.py @@ -312,9 +312,12 @@ def conv_transpose2d( dilation=1, ): """Map of Pytorch method torch.nn.functional.conv_transpose2d.""" - return P.conv_transpose2d( - input, weight, bias, stride, padding, output_padding, groups, dilation + ret = P.conv_transpose2d( + input, weight, stride, padding, output_padding, groups, dilation ) + if bias is not None: + ret = ret + reshape(bias, (1, bias.shape[0], 1, 1)) + return ret @core diff --git a/requirements-cpu.conda b/requirements-cpu.conda index d7fbe687e..02c0771f4 100644 --- a/requirements-cpu.conda +++ b/requirements-cpu.conda @@ -2,7 +2,7 @@ python>=3.7,<3.8a0 colorama numpy opt_einsum -abergeron::tvm==0.7dev1+0.* +abergeron::tvm==0.7dev1+1.* pytest pytest-cov yaml diff --git a/requirements-gpu.conda b/requirements-gpu.conda index 0115e9dcb..51a94fe8e 100644 --- a/requirements-gpu.conda +++ b/requirements-gpu.conda @@ -4,7 +4,7 @@ numpy opt_einsum cudatoolkit=10.0.* cudnn -abergeron::tvm==0.7dev1+0.* +abergeron::tvm==0.7dev1+1.* pytest pytest-cov yaml diff --git a/tests/frontends/test_all_ops.py b/tests/frontends/test_all_ops.py index b1e461de7..d27a86a1f 100644 --- a/tests/frontends/test_all_ops.py +++ b/tests/frontends/test_all_ops.py @@ -46,7 +46,34 @@ @eqtest.register def eqtest(t1: torch.Tensor, t2, rtol=1e-5, atol=1e-8, **kwargs): - return torch.allclose(t1, t2, equal_nan=True, atol=atol, rtol=rtol) + """ New version of eqtest using np.testing.assert_allclose. + If comparison fails, this version will raise an exception + and display a more informative log if comparison fail, + especially max absolute and relative difference. + """ + # Quick debug code to display mismatching values + shape = t1.shape + if len(shape) > 1: + x1 = t1.flatten() + x2 = t2.flatten() + s = x1.shape[0] + c = 0 + for i in range(s): + v1 = x1[i].item() + v2 = x2[i].item() + if abs(v1 - v2) > atol: + c += 1 + print("diff", c, i + 1, v1, v2) + # End debug code + + np.testing.assert_allclose( + t1.detach().numpy(), + t2.detach().numpy(), + rtol=rtol, + atol=atol, + verbose=True, + ) + return True @eqtest.register @@ -241,6 +268,7 @@ def _run( pipeline=standard_pipeline, backend=None, numpy_compat=True, + **kwargs, ): """Test a Myia function. @@ -290,7 +318,7 @@ def out(args): if result is None: result = fn(*args) - self.check(out, args, result) + self.check(out, args, result, **kwargs) if numpy_compat: args_torch = args @@ -476,7 +504,6 @@ def test_conv2d_no_dil_stride(inp, w): nn.Parameter(torch.randn(3, 2, 3, 3, dtype=torch.float32)), None, ), - backend=backend_no_relay, ) def test_torch_conv2d(inp, w, b): value = torch.nn.functional.conv2d(inp, w, b, (2, 3), (3, 2), (3, 4), 3) @@ -494,14 +521,13 @@ def test_torch_conv2d(inp, w, b): nn.Parameter(torch.randn(3, 2, 3, 3, dtype=torch.float32)), None, ), - backend=backend_no_relay, ) def test_torch_conv2d__non_tuple_args(inp, w, b): value = torch.nn.functional.conv2d(inp, w, b, 2, 3, 4, 3) return torch.sum(value) -@fwd_and_bwd_no_relay( +@fwd_and_bwd( nn.Parameter(torch.randn(2, 1, 4, 5, dtype=torch.float32)), nn.Parameter(torch.randn(3, 1, 3, 3, dtype=torch.float32)), nn.Parameter(torch.randn(3, dtype=torch.float32)), @@ -512,7 +538,8 @@ def test_torch_conv2d__group3(inp, w, b): @mt( - run_no_relay( + # with bias + run( torch.randn(1, 2, 4, 4), torch.randn(2, 3, 2, 2), torch.randn(6), @@ -522,7 +549,30 @@ def test_torch_conv2d__group3(inp, w, b): 2, (1, 1), ), - run_no_relay( + # no bias + run( + torch.randn(1, 2, 4, 4), + torch.randn(2, 3, 2, 2), + None, + (1, 1), + (1, 1), + (0, 0), + 2, + (1, 1), + ), + # with bias + run( + torch.randn(1, 2, 5, 4), + torch.randn(2, 4, 1, 3), + torch.randn(8), + (2, 3), + (4, 5), + (3, 2), + 2, + (5, 4), + ), + # no bias + run( torch.randn(1, 2, 5, 4), torch.randn(2, 4, 1, 3), None, @@ -532,7 +582,19 @@ def test_torch_conv2d__group3(inp, w, b): 2, (5, 4), ), - run_no_relay( + # with bias + run( + torch.randn(5, 2, 5, 6), + torch.randn(2, 2, 4, 4), + torch.randn(2), + (1, 1), + (0, 0), + (0, 0), + 1, + (1, 1), + ), + # no bias + run( torch.randn(5, 2, 5, 6), torch.randn(2, 2, 4, 4), None, @@ -542,7 +604,19 @@ def test_torch_conv2d__group3(inp, w, b): 1, (1, 1), ), - run_no_relay( + # with bias + run( + torch.randn(1, 1, 4, 4), + torch.randn(1, 3, 2, 2), + torch.randn(3), + (1, 1), + (1, 1), + (0, 0), + 1, + (1, 1), + ), + # no bias + run( torch.randn(1, 1, 4, 4), torch.randn(1, 3, 2, 2), None, @@ -553,6 +627,7 @@ def test_torch_conv2d__group3(inp, w, b): (1, 1), ), broad_specs=(True, True, False, False, False, False, False, False), + atol=1e-5, ) def test_torch_conv_transpose2d(i, w, b, s, p, o_p, g, d): return torch.nn.functional.conv_transpose2d(i, w, b, s, p, o_p, g, d) diff --git a/tests/multitest.py b/tests/multitest.py index c87ccbe4b..8ef430fb7 100644 --- a/tests/multitest.py +++ b/tests/multitest.py @@ -91,7 +91,7 @@ def configure(self, **spec): """Configure this test with new kwargs.""" return MyiaFunctionTest(self.runtest, spec={**self.spec, **spec}) - def check(self, run, args, expected): + def check(self, run, args, expected, **kwargs): """Check the result of run() against expected. Expected can be either: @@ -118,7 +118,7 @@ def check(self, run, args, expected): if not expected(args, res): raise Exception(f"Failed the result check function") - elif not eqtest(res, expected): + elif not eqtest(res, expected, **kwargs): raise Exception(f"Mismatch: expected {expected}, got {res}") def generate_params(self):