diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 1499de1f6..b1a5a30ae 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -10121,6 +10121,9 @@ ], "kwargs_change": { "dim": "axis" + }, + "paddle_default_kwargs": { + "axis": 0 } }, "torch.nn.MSELoss": { @@ -10771,9 +10774,6 @@ "bias_hh_attr" ] }, - "unsupport_args": [ - "proj_size" - ], "min_input_args": 3 }, "torch.nn.RNNCell": { @@ -10952,6 +10952,9 @@ ], "kwargs_change": { "dim": "axis" + }, + "paddle_default_kwargs": { + "axis": 0 } }, "torch.nn.Softmax2d": { diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index c8ab13544..27478982d 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -363,7 +363,7 @@ def get_paddle_class_attribute_nodes(self, node): def get_paddle_nodes(self, args, kwargs): new_args = self.parse_args(args) new_kwargs = self.parse_kwargs(kwargs) - if new_kwargs is not None and new_args is not None: + if new_kwargs is not None: code = "{}({})".format( self.get_paddle_api(), self.args_and_kwargs_to_str(new_args, new_kwargs) ) @@ -3829,8 +3829,8 @@ def generate_code(self, kwargs): class SoftmaxMatcher(BaseMatcher): def generate_code(self, kwargs): if "dim" not in kwargs or "None" in kwargs["dim"]: - return None - + kwargs.pop("dim", "None") + kwargs["axis"] = 0 return GenericMatcher.generate_code(self, kwargs) diff --git a/tests/test_max_pool2d.py b/tests/test_max_pool2d.py index 9ee2b5ca6..051719689 100644 --- a/tests/test_max_pool2d.py +++ b/tests/test_max_pool2d.py @@ -104,3 +104,60 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = torch.max_pool2d(input=input, kernel_size=3, stride=(2, 1), padding=1, dilation=1, ceil_mode=True) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = torch.max_pool2d(input=input, stride=(2, 1), kernel_size=3, dilation=1, padding=1, ceil_mode=True) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = torch.max_pool2d(input, 3, (2, 1), 1, 1, True) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) diff --git a/tests/test_max_pool3d.py b/tests/test_max_pool3d.py index 94977746f..246524e8f 100644 --- a/tests/test_max_pool3d.py +++ b/tests/test_max_pool3d.py @@ -79,3 +79,29 @@ def test_case_5(): check_dtype=False, reason="torch indices dtype is int64, while paddle is int32", ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.arange(4800, dtype=torch.float32).reshape(2, 3, 8, 10, 10) + result = 
torch.max_pool3d(input=input, kernel_size=(2, 2, 2), stride=(2, 1, 1), padding=1, dilation=1, ceil_mode=True) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.arange(4800, dtype=torch.float32).reshape(2, 3, 8, 10, 10) + result = torch.max_pool3d(input=input, dilation=1, kernel_size=(2, 2, 2), padding=1, stride=(2, 1, 1), ceil_mode=True) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) diff --git a/tests/test_mean.py b/tests/test_mean.py index 999326120..20cf20c72 100644 --- a/tests/test_mean.py +++ b/tests/test_mean.py @@ -85,3 +85,27 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]]) + out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]], dtype=torch.float64) + result = torch.mean(input=input, dim=1, keepdim=True, dtype=torch.float64, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]]) + out = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]], dtype=torch.float64) + result = torch.mean(input=input, keepdim=True, dim=1, out=out, dtype=torch.float64) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_min.py b/tests/test_min.py index 5386a1940..74bd10387 100644 --- a/tests/test_min.py +++ b/tests/test_min.py @@ -150,3 +150,27 @@ def test_case_12(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_13(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2, 1], [3, 4, 6]]) + out = [torch.tensor(0), torch.tensor(1)] + result = torch.min(input=x, dim=1, keepdim=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_14(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1, 2, 1], [3, 4, 6]]) + out = [torch.tensor(0), torch.tensor(1)] + result = torch.min(input=x, keepdim=False, dim=1, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_nn_ConvTranspose1d.py b/tests/test_nn_ConvTranspose1d.py index e1cabe6ef..276dc7258 100644 --- a/tests/test_nn_ConvTranspose1d.py +++ b/tests/test_nn_ConvTranspose1d.py @@ -76,8 +76,30 @@ def test_case_5(): """ import torch import torch.nn as nn + + x = torch.randn(5, 16, 50) + model = nn.ConvTranspose1d(in_channels=16, out_channels=33, kernel_size=5, stride=1, padding=4, output_padding=0, groups=1, bias=True, dilation=3, + padding_mode='zeros', device=None, dtype=None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + x = torch.randn(5, 16, 50) - model = nn.ConvTranspose1d(16, 33, 5, stride=1, padding=4, dilation=3, bias=True, padding_mode='zeros') + model = nn.ConvTranspose1d(in_channels=16, kernel_size=5, out_channels=33, stride=1, padding=4, device=None, groups=1, bias=True, output_padding=0, dilation=3, + padding_mode='zeros', dtype=None) result = model(x) """ ) @@ -87,3 +109,36 @@ def test_case_5(): unsupport=True, reason="Paddle does 
not support parameter of padding_mode", ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + x = torch.randn(5, 16, 50) + model = nn.ConvTranspose1d(16, 33, 5, 1, 4, 0, 1, True, 3, 'zeros', None, None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + x = torch.randn(5, 16, 50) + model = nn.ConvTranspose1d(16, 33, 5) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_ConvTranspose2d.py b/tests/test_nn_ConvTranspose2d.py index e0e0f8d15..7d7d37837 100644 --- a/tests/test_nn_ConvTranspose2d.py +++ b/tests/test_nn_ConvTranspose2d.py @@ -87,3 +87,88 @@ def test_case_5(): unsupport=True, reason="Paddle does not support parameter of padding_mode", ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 100) + model = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1), bias=True, padding_mode='zeros') + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 100) + model = nn.ConvTranspose2d(in_channels=16, out_channels=33, kernel_size=(3, 5), stride=(2, 1), padding=(4, 2), output_padding=0, groups=1, bias=True, dilation=(3, 1), padding_mode='zeros', device=None, dtype=None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 100) + model = nn.ConvTranspose2d(in_channels=16, output_padding=0, kernel_size=(3, 5), stride=(2, 1), padding=(4, 2), groups=1, bias=True, out_channels=33, dilation=(3, 1), padding_mode='zeros', device=None, dtype=None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 100) + model = nn.ConvTranspose2d(16, 33, (3, 5), (2, 1), (4, 2), 0, 1, True, (3, 1), 'zeros', None, None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 100) + model = nn.ConvTranspose2d(16, 33, (3, 5)) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_ConvTranspose3d.py b/tests/test_nn_ConvTranspose3d.py index 22a01363f..cd54fd701 100644 --- a/tests/test_nn_ConvTranspose3d.py +++ b/tests/test_nn_ConvTranspose3d.py @@ -87,3 +87,70 @@ def test_case_5(): unsupport=True, reason="Paddle does not support parameter of padding_mode", ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 20, 20) + model = 
nn.ConvTranspose3d(in_channels=16, out_channels=33, kernel_size=(3, 3, 5), stride=(2, 2, 1), padding=(4, 2, 2), output_padding=0, groups=1, bias=True, dilation=(3, 1, 1), padding_mode='zeros', device=None, dtype=None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 20, 20) + model = nn.ConvTranspose3d(in_channels=16, kernel_size=(3, 3, 5), out_channels=33, stride=(2, 2, 1), device=None, padding=(4, 2, 2), bias=True, output_padding=0, groups=1, dilation=(3, 1, 1), padding_mode='zeros', dtype=None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 20, 20) + model = nn.ConvTranspose3d(16, 33, (3, 3, 5), (2, 2, 1), (4, 2, 2), 0, 1, True, (3, 1, 1), 'zeros', None, None) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle does not support parameter of padding_mode", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.randn(5, 16, 50, 20, 20) + model = nn.ConvTranspose3d(16, 33, (3, 3, 5)) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_Embedding.py b/tests/test_nn_Embedding.py index 748bbefdb..4a195b3f1 100644 --- a/tests/test_nn_Embedding.py +++ b/tests/test_nn_Embedding.py @@ -80,7 +80,11 @@ def test_case_3(): result = embedding(x) """ ) - obj.run(pytorch_code, unsupport=True, reason="paddle unsupport") + obj.run( + pytorch_code, + unsupport=True, + reason="Paddle does not support parameter of max_norm", + ) def test_case_4(): @@ -88,7 +92,39 @@ def test_case_4(): """ import torch padding_idx = 0 - embedding = torch.nn.Embedding(4, 3,padding_idx=padding_idx,max_norm=2.0) + embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=3, padding_idx=padding_idx, max_norm=2.0, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + result = embedding.padding_idx + """ + ) + obj.run( + pytorch_code, + unsupport=True, + reason="Paddle does not support parameter of max_norm", + ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + padding_idx = 0 + embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=3, scale_grad_by_freq=False, max_norm=2.0, norm_type=2.0, padding_idx=padding_idx, sparse=False) + result = embedding.padding_idx + """ + ) + obj.run( + pytorch_code, + unsupport=True, + reason="Paddle does not support parameter of max_norm", + ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + padding_idx = 0 + embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=3, padding_idx=padding_idx, norm_type=2.0, scale_grad_by_freq=False, sparse=False) result = embedding.padding_idx """ ) diff --git a/tests/test_nn_Identity.py b/tests/test_nn_Identity.py index 438842966..b6198fbf4 100644 --- a/tests/test_nn_Identity.py +++ b/tests/test_nn_Identity.py @@ -30,3 +30,17 @@ def test_case_1(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = nn.Identity(54, unused_argument1=0.1, 
unused_argument2=False) + input = torch.ones(128, 20) + output = m(input) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_LogSoftmax.py b/tests/test_nn_LogSoftmax.py index 31e64aee3..794d8905a 100644 --- a/tests/test_nn_LogSoftmax.py +++ b/tests/test_nn_LogSoftmax.py @@ -70,9 +70,22 @@ def test_case_3(): result = model(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.LogSoftmax(dim=None) + result = model(x) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_MaxPool1d.py b/tests/test_nn_MaxPool1d.py index 26f9d83e8..82b74fcfc 100644 --- a/tests/test_nn_MaxPool1d.py +++ b/tests/test_nn_MaxPool1d.py @@ -100,3 +100,57 @@ def test_case_6(): unsupport=True, reason="paddle.nn.MaxPool1D dose not support 'dilation' now!", ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[0.1, 1., 2., 3.], [4., 5., 6., 7.]]]) + model = nn.MaxPool1d(kernel_size=2, stride=1, padding=1, dilation=2, return_indices=True, ceil_mode=False) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool1D dose not support 'dilation' now!", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[0.1, 1., 2., 3.], [4., 5., 6., 7.]]]) + model = nn.MaxPool1d(kernel_size=2, dilation=2, padding=1, return_indices=True, stride=1, ceil_mode=False) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool1D dose not support 'dilation' now!", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[0.1, 1., 2., 3.], [4., 5., 6., 7.]]]) + model = nn.MaxPool1d(2, 1, 1, 2, True, False) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool1D dose not support 'dilation' now!", + ) diff --git a/tests/test_nn_MaxPool2d.py b/tests/test_nn_MaxPool2d.py index 5a533e29f..5ba80363f 100644 --- a/tests/test_nn_MaxPool2d.py +++ b/tests/test_nn_MaxPool2d.py @@ -101,3 +101,57 @@ def test_case_6(): unsupport=True, reason="paddle.nn.MaxPool2D dose not support 'dilation' now!", ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[0., 1., 2., 3.], [4., 5., 6., 7.]]]]) + model = nn.MaxPool2d(kernel_size=2, stride=1, padding=1, dilation=2, return_indices=True, ceil_mode=True) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool2D dose not support 'dilation' now!", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[0., 1., 2., 3.], [4., 5., 6., 7.]]]]) + model = nn.MaxPool2d(kernel_size=2, dilation=2, stride=1, return_indices=True, padding=1, ceil_mode=True) + result, indices = model(x) + """ + ) + obj.run( 
+ pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool2D dose not support 'dilation' now!", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[0., 1., 2., 3.], [4., 5., 6., 7.]]]]) + model = nn.MaxPool2d(2, 1, 1, 2, True, True) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool2D dose not support 'dilation' now!", + ) diff --git a/tests/test_nn_MaxPool3d.py b/tests/test_nn_MaxPool3d.py index 9947018e9..bf6ab96cb 100644 --- a/tests/test_nn_MaxPool3d.py +++ b/tests/test_nn_MaxPool3d.py @@ -160,3 +160,87 @@ def test_case_6(): unsupport=True, reason="paddle.nn.MaxPool3D dose not support 'dilation' now!", ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-0.8658, 1.0869, -2.1977], + [-2.1073, 1.0974, -1.4485], + [ 0.5880, -0.7189, 0.1089]], + + [[ 1.3036, 0.3086, -1.2245], + [-0.6707, -0.0195, -0.1474], + [ 0.2727, -0.4938, -0.6854]], + + [[ 0.5525, 1.0111, -0.1847], + [ 0.1111, -0.6373, -0.2220], + [-0.5963, 0.7734, 0.0409]]]]]) + model = nn.MaxPool3d(kernel_size=2, stride=1, padding=1, dilation=2, return_indices=True, ceil_mode=True) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool3D dose not support 'dilation' now!", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-0.8658, 1.0869, -2.1977], + [-2.1073, 1.0974, -1.4485], + [ 0.5880, -0.7189, 0.1089]], + + [[ 1.3036, 0.3086, -1.2245], + [-0.6707, -0.0195, -0.1474], + [ 0.2727, -0.4938, -0.6854]], + + [[ 0.5525, 1.0111, -0.1847], + [ 0.1111, -0.6373, -0.2220], + [-0.5963, 0.7734, 0.0409]]]]]) + model = nn.MaxPool3d(kernel_size=2, padding=1, stride=1, return_indices=True, dilation=2, ceil_mode=True) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool3D dose not support 'dilation' now!", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-0.8658, 1.0869, -2.1977], + [-2.1073, 1.0974, -1.4485], + [ 0.5880, -0.7189, 0.1089]], + + [[ 1.3036, 0.3086, -1.2245], + [-0.6707, -0.0195, -0.1474], + [ 0.2727, -0.4938, -0.6854]], + + [[ 0.5525, 1.0111, -0.1847], + [ 0.1111, -0.6373, -0.2220], + [-0.5963, 0.7734, 0.0409]]]]]) + model = nn.MaxPool3d(2, 1, 1, True, 2, True) + result, indices = model(x) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle.nn.MaxPool3D dose not support 'dilation' now!", + ) diff --git a/tests/test_nn_Module_register_forward_hook.py b/tests/test_nn_Module_register_forward_hook.py index d3f2c11b5..fa4869e54 100644 --- a/tests/test_nn_Module_register_forward_hook.py +++ b/tests/test_nn_Module_register_forward_hook.py @@ -107,3 +107,33 @@ def hook(module, fea_in, fea_out): unsupport=True, reason="prepend, with_kwargs and always_call is not supported", ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + result = [] + class TestForHook(nn.Module): + def __init__(self): + super().__init__() + + self.linear_1 = nn.Linear(in_features=2, out_features=2) + def forward(self, x): + x1 = self.linear_1(x) + return x, x, x1 + def hook(module, 
fea_in, fea_out): + result.append(1) + + net = TestForHook() + net.register_forward_hook(prepend=False, hook=hook, always_call=False ,with_kwargs=False) + a = torch.tensor([0.,0.]) + net(a) + """ + ) + obj.run( + pytorch_code, + unsupport=True, + reason="prepend, with_kwargs and always_call is not supported", + ) diff --git a/tests/test_nn_Module_register_forward_pre_hook.py b/tests/test_nn_Module_register_forward_pre_hook.py index 664704eba..c4a611541 100644 --- a/tests/test_nn_Module_register_forward_pre_hook.py +++ b/tests/test_nn_Module_register_forward_pre_hook.py @@ -105,3 +105,31 @@ def hook(module, fea_in): obj.run( pytorch_code, unsupport=True, reason="prepend and with_kwargs is not supported" ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + result = [] + class TestForHook(nn.Module): + def __init__(self): + super().__init__() + + self.linear_1 = nn.Linear(in_features=2, out_features=2) + def forward(self, x): + x1 = self.linear_1(x) + return x, x1 + def hook(module, fea_in): + result.append(1) + + net = TestForHook() + net.register_forward_pre_hook(hook=hook, with_kwargs=False, prepend=False) + a = torch.tensor([0.,0.]) + net(a) + """ + ) + obj.run( + pytorch_code, unsupport=True, reason="prepend and with_kwargs is not supported" + ) diff --git a/tests/test_nn_Softmax.py b/tests/test_nn_Softmax.py index b7279379d..5f4a49d01 100644 --- a/tests/test_nn_Softmax.py +++ b/tests/test_nn_Softmax.py @@ -70,9 +70,22 @@ def test_case_3(): result = model(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 10.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmax(dim=None) + result = model(x) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_layer_rnn_RNNBase.py b/tests/test_nn_layer_rnn_RNNBase.py index 518db624f..724c0791d 100644 --- a/tests/test_nn_layer_rnn_RNNBase.py +++ b/tests/test_nn_layer_rnn_RNNBase.py @@ -103,6 +103,48 @@ def test_case_6(): obj.run(pytorch_code, ["model_args"]) +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + model = torch.nn.RNNBase(mode='RNN_RELU', input_size=16, hidden_size=32, num_layers=2, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device='cpu', dtype=None) + + model_args = (model.input_size, model.hidden_size, model.num_layers) + """ + ) + obj.run(pytorch_code, ["model_args"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + model = torch.nn.RNNBase('RNN_RELU', 16, 32, 2, True, False, 0.0, False, 0, 'cpu', None) + + model_args = (model.input_size, model.hidden_size, model.num_layers) + """ + ) + obj.run(pytorch_code, ["model_args"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + + model = torch.nn.RNNBase(mode='RNN_RELU', hidden_size=32, batch_first=False, num_layers=2, bias=True, input_size=16, dropout=0.0, bidirectional=False, proj_size=0, device='cpu', dtype=None) + + model_args = (model.input_size, model.hidden_size, model.num_layers) + """ + ) + obj.run(pytorch_code, ["model_args"]) + + def test_mode_case_1(): 
pytorch_code = textwrap.dedent( """