diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index c63890104..72003fe04 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -11054,11 +11054,13 @@ "layer_norm_eps", "batch_first", "norm_first", + "bias", "device", "dtype" ], "kwargs_change": { "norm_first": "normalize_before", + "bias": "bias_attr", "device": "", "dtype": "" }, @@ -11090,11 +11092,13 @@ "layer_norm_eps", "batch_first", "norm_first", + "bias", "device", "dtype" ], "kwargs_change": { "norm_first": "normalize_before", + "bias": "bias_attr", "device": "", "dtype": "" }, @@ -11134,11 +11138,13 @@ "layer_norm_eps", "batch_first", "norm_first", + "bias", "device", "dtype" ], "kwargs_change": { "norm_first": "normalize_before", + "bias": "bias_attr", "device": "", "dtype": "" }, @@ -12092,7 +12098,7 @@ } }, "torch.nn.functional.log_softmax": { - "Matcher": "RequireDimMatcher", + "Matcher": "SoftmaxMatcher", "paddle_api": "paddle.nn.functional.log_softmax", "args_list": [ "input", @@ -12215,9 +12221,6 @@ "kwargs_change": { "input": "x" }, - "unsupport_args": [ - "output_size" - ], "min_input_args": 3 }, "torch.nn.functional.max_unpool2d": { @@ -12605,7 +12608,7 @@ "min_input_args": 2 }, "torch.nn.functional.softmax": { - "Matcher": "RequireDimMatcher", + "Matcher": "SoftmaxMatcher", "paddle_api": "paddle.nn.functional.softmax", "args_list": [ "input", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 186f32d77..d2338cb3e 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3829,8 +3829,8 @@ def generate_code(self, kwargs): class SoftmaxMatcher(BaseMatcher): def generate_code(self, kwargs): if "dim" not in kwargs or "None" in kwargs["dim"]: - return None - + kwargs.pop("dim", "None") + kwargs["axis"] = 0 return GenericMatcher.generate_code(self, kwargs) diff --git a/tests/test_nn_LogSoftmax.py b/tests/test_nn_LogSoftmax.py index 31e64aee3..794d8905a 100644 --- a/tests/test_nn_LogSoftmax.py +++ b/tests/test_nn_LogSoftmax.py @@ -70,9 +70,22 @@ def test_case_3(): result = model(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.LogSoftmax(dim=None) + result = model(x) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_Softmax.py b/tests/test_nn_Softmax.py index b7279379d..5f4a49d01 100644 --- a/tests/test_nn_Softmax.py +++ b/tests/test_nn_Softmax.py @@ -70,9 +70,22 @@ def test_case_3(): result = model(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results", + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 10.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmax(dim=None) + result = model(x) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_Transformer.py b/tests/test_nn_Transformer.py index fab88af7e..06eec8c2b 100644 --- a/tests/test_nn_Transformer.py +++ b/tests/test_nn_Transformer.py @@ -148,3 
+148,72 @@ def test_case_7(): unsupport=True, reason="paddle unsupport batch_first args", ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + transformer_model = torch.nn.Transformer(d_model=512, + nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, + dropout=0.1, activation='relu', + custom_encoder=None, custom_decoder=None, + layer_norm_eps=1e-05, batch_first=False, + norm_first=False, bias=False, + device=None, dtype=None) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((10, 32, 512)) + result = transformer_model(src, tgt) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle unsupport layer_norm_eps args", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + transformer_model = torch.nn.Transformer(512, + 8, 6, 6, 2048, + 0.1, 'relu', + None, None, + 1e-05, False, + False, False, + None, None) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((10, 32, 512)) + result = transformer_model(src, tgt) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle unsupport layer_norm_eps args", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + transformer_model = torch.nn.Transformer(d_model=512, + nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, + dropout=0.1, activation='relu', + custom_encoder=None, custom_decoder=None, + norm_first=False, bias=True, device=None, dtype=None) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((10, 32, 512)) + result = transformer_model(src, tgt) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_TransformerDecoderLayer.py b/tests/test_nn_TransformerDecoderLayer.py index bd773a830..a22767a5d 100644 --- a/tests/test_nn_TransformerDecoderLayer.py +++ b/tests/test_nn_TransformerDecoderLayer.py @@ -78,3 +78,60 @@ def test_case_4(): unsupport=True, reason="paddle unsupport batch_first args", ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.ones(10, 32,512) + tgt = torch.ones(10, 32, 512) + model = nn.TransformerDecoderLayer(d_model=512, nhead=8,dim_feedforward=2048, dropout=0.1, + activation="relu", layer_norm_eps=1e-06, batch_first=False, + norm_first=False, bias=True, device=None, dtype=None) + result = model(tgt,x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle unsupport batch_first args", + ) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.ones(10, 32,512) + tgt = torch.ones(10, 32, 512) + model = nn.TransformerDecoderLayer(512, 8,2048, 0.1, "relu", 1e-06, False, + False, True, None, None) + result = model(tgt,x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle unsupport batch_first args", + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.ones(10, 32,512) + tgt = torch.ones(10, 32, 512) + model = nn.TransformerDecoderLayer(d_model=512, nhead=8,dim_feedforward=2048, dropout=0.1, + activation="relu", layer_norm_eps=1e-06, + norm_first=False, bias=True, device=None, dtype=None) + result = model(tgt,x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_TransformerEncoderLayer.py b/tests/test_nn_TransformerEncoderLayer.py index 99153bddf..51ab45330 
100644 --- a/tests/test_nn_TransformerEncoderLayer.py +++ b/tests/test_nn_TransformerEncoderLayer.py @@ -123,3 +123,64 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + tgt = torch.ones(10, 32, 512) + model = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1, + activation="relu", layer_norm_eps=1e-05, batch_first=False, + norm_first=False, bias=True, device=None, dtype=None) + result = model(tgt) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle unsupport batch_first args", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + tgt = torch.ones(10, 32, 512) + model = nn.TransformerEncoderLayer(512, 8, 2048, 0.1, + "relu", 1e-05, False, + False, True, "cpu", None) + result = model(tgt) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle unsupport batch_first args", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + tgt = torch.ones(10, 32, 512) + model = nn.TransformerEncoderLayer(512, 8, + 2048, + 0.1, 'relu', + norm_first=False, + device=None, + bias=True, + dtype=torch.float32) + result = model(tgt) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_Upsample.py b/tests/test_nn_Upsample.py index 0c551208b..a84fc9d1b 100644 --- a/tests/test_nn_Upsample.py +++ b/tests/test_nn_Upsample.py @@ -127,3 +127,21 @@ def test_case_6(): obj.run( pytorch_code, unsupport=True, reason="paddle unsupport recompute_scale_factor " ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + m = torch.nn.Upsample(scale_factor=2, align_corners=True, mode='bilinear') + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_embedding.py b/tests/test_nn_functional_embedding.py index abc4a11e8..dabbbb7bd 100644 --- a/tests/test_nn_functional_embedding.py +++ b/tests/test_nn_functional_embedding.py @@ -58,11 +58,61 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - w0 = torch.Tensor([[0., 0., 0.], - [1., 1., 1.], - [2., 2., 2.], - [3., 3., 3.]]) - result = torch.nn.functional.embedding(x,embedding_matrix,padding_idx=0,max_norm=2) + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(x, embedding_matrix, padding_idx=0, max_norm=2) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport max_norm") + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(input=x, weight=embedding_matrix, padding_idx=0, max_norm=2, norm_type=2.0, scale_grad_by_freq=False, sparse=True) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport max_norm ") + + +def test_case_5(): + pytorch_code = 
textwrap.dedent( + """ + import torch + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(input=x, padding_idx=0, max_norm=2, weight=embedding_matrix, scale_grad_by_freq=False, norm_type=2.0, sparse=True) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport max_norm ") + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import numpy as np + embedding_matrix = torch.Tensor([[0., 0., 0.], + [1., 1., 1.], + [2., 2., 2.], + [3., 3., 3.]]) + x = torch.tensor(np.array([[0,1],[2,3]])) + result = torch.nn.functional.embedding(x, embedding_matrix, 0, 2, 2.0, False, True) """ ) - obj.run(pytorch_code, unsupport=True, reason="paddle unsupport") + obj.run(pytorch_code, unsupport=True, reason="paddle unsupport max_norm ") diff --git a/tests/test_nn_functional_interpolate.py b/tests/test_nn_functional_interpolate.py index 8252ec5c7..97a63c1b9 100644 --- a/tests/test_nn_functional_interpolate.py +++ b/tests/test_nn_functional_interpolate.py @@ -77,3 +77,59 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[1., 2., 3.], [2., 3., 4.]]]) + result = F.interpolate(input=x, size=None, scale_factor=3, mode='linear', align_corners=False, + recompute_scale_factor=False, antialias=False) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle unsupport parameter recompute_scale_factor", + ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[1., 2., 3.], [2., 3., 4.]]]) + result = F.interpolate(input=x, scale_factor=3, size=None, recompute_scale_factor=False, mode='linear', align_corners=False, + antialias=False) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle unsupport parameter recompute_scale_factor", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[1., 2., 3.], [2., 3., 4.]]]) + result = F.interpolate(x, None, 3, 'linear', False, False, False) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle unsupport parameter recompute_scale_factor", + ) diff --git a/tests/test_nn_functional_log_softmax.py b/tests/test_nn_functional_log_softmax.py index 8ab2e71e8..892072f47 100644 --- a/tests/test_nn_functional_log_softmax.py +++ b/tests/test_nn_functional_log_softmax.py @@ -33,12 +33,7 @@ def test_case_1(): result = F.log_softmax(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results due to the way to calculate dimensions", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -72,12 +67,7 @@ def test_case_3(): result = F.log_softmax(x, dtype=torch.float64) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results due to the way to calculate dimensions", - ) + obj.run(pytorch_code, ["result"]) def test_case_4(): @@ -94,9 +84,55 @@ def test_case_4(): result = F.log_softmax(x, _stacklevel=2) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results due 
to the way to calculate dimensions", + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.log_softmax(input=x, dim=2, _stacklevel=2, dtype=torch.float64) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.log_softmax(x, 2, 2, torch.float64) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.log_softmax(input=x, _stacklevel=2, dtype=torch.float64, dim=-2) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_max_pool1d.py b/tests/test_nn_functional_max_pool1d.py index e85b7ac45..43c51e41d 100644 --- a/tests/test_nn_functional_max_pool1d.py +++ b/tests/test_nn_functional_max_pool1d.py @@ -97,7 +97,7 @@ def test_case_5(): # when return_indices=False, paddle result and indices shape is (1, 3, 2), which is right: ceil(6/5)=2 # when return_indices=True, paddle result and indices shape is (1, 3, 1), which is bug -def _test_case_6(): +def test_case_6(): pytorch_code = textwrap.dedent( """ import torch @@ -114,3 +114,60 @@ def _test_case_6(): check_dtype=False, reason="torch indices dtype is int64, while paddle is int32", ) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], + [-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], + [ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) + result, indices = F.max_pool1d(input=input, kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=True, return_indices=True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle unsupport parameter dilation", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], + [-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], + [ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) + result, indices = F.max_pool1d(kernel_size=5, stride=2, dilation=1, input=input, padding=0, ceil_mode=True, return_indices=True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle unsupport parameter dilation", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], + [-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], + [ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) + result, indices = F.max_pool1d(input, 5, 2, 0, 1, True, True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="paddle unsupport parameter 
dilation", + ) diff --git a/tests/test_nn_functional_max_pool2d.py b/tests/test_nn_functional_max_pool2d.py index 8fae90bc9..f942f63cd 100644 --- a/tests/test_nn_functional_max_pool2d.py +++ b/tests/test_nn_functional_max_pool2d.py @@ -138,7 +138,7 @@ def test_case_6(): # when return_indices=False, paddle result and indices shape is (1, 3, 2, 2), which is right: ceil(6/5)=2 # when return_indices=True, paddle result and indices shape is (1, 3, 1, 1), which is bug -def _test_case_7(): +def test_case_7(): pytorch_code = textwrap.dedent( """ import torch @@ -153,3 +153,77 @@ def _test_case_7(): check_dtype=False, reason="torch indices dtype is int64, while paddle is int32", ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + input = torch.tensor([[[[1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [0.1507, -1.1431, -2.0361]], + + [[0.1024, -0.4482, 0.4137], + [0.9385, 0.4565, 0.7702], + [0.4135, -0.2587, 0.0482]]]]) + result, indices = F.max_pool2d(input=input, kernel_size=(2, 2), stride=2, padding=1, dilation=2, ceil_mode=False, + return_indices=True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="dilation is not supported now", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + input = torch.tensor([[[[1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [0.1507, -1.1431, -2.0361]], + + [[0.1024, -0.4482, 0.4137], + [0.9385, 0.4565, 0.7702], + [0.4135, -0.2587, 0.0482]]]]) + result, indices = F.max_pool2d(input, (2, 2), 2, 1, 2, False, True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="dilation is not supported now", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + input = torch.tensor([[[[1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [0.1507, -1.1431, -2.0361]], + + [[0.1024, -0.4482, 0.4137], + [0.9385, 0.4565, 0.7702], + [0.4135, -0.2587, 0.0482]]]]) + result, indices = F.max_pool2d(input=input, stride=2, padding=1, dilation=2, kernel_size=(2, 2), ceil_mode=False, + return_indices=True) + """ + ) + obj.run( + pytorch_code, + ["result", "indices"], + unsupport=True, + reason="dilation is not supported now", + ) diff --git a/tests/test_nn_functional_max_pool3d.py b/tests/test_nn_functional_max_pool3d.py index 64b5b7d35..3b76d9b38 100644 --- a/tests/test_nn_functional_max_pool3d.py +++ b/tests/test_nn_functional_max_pool3d.py @@ -102,7 +102,7 @@ def test_case_6(): # when return_indices=False, paddle result and indices shape is (1, 3, 2, 2, 2), which is right: ceil(10/8)=2 # when return_indices=True, paddle result and indices shape is (1, 3, 1, 1, 1), which is bug -def _test_case_7(): +def test_case_7(): pytorch_code = textwrap.dedent( """ import torch @@ -117,3 +117,17 @@ def _test_case_7(): check_dtype=False, reason="torch indices dtype is int64, while paddle is int32", ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.arange(4800, dtype=torch.float32).reshape(2, 3, 8, 10, 10) + result = F.max_pool3d(input, 3, 1, 1, 2, True, False) + """ + ) + obj.run( + pytorch_code, ["result"], unsupport=True, reason="dilation is not supported now" + ) diff --git a/tests/test_nn_functional_max_unpool1d.py b/tests/test_nn_functional_max_unpool1d.py index 48e046558..c254f95b2 100644 --- 
a/tests/test_nn_functional_max_unpool1d.py +++ b/tests/test_nn_functional_max_unpool1d.py @@ -74,12 +74,7 @@ def test_case_4(): result = F.max_unpool1d(x, indices, kernel_size=2, output_size=(1, 1, 4)) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle will generate error when the output_size parameter is specified", - ) + obj.run(pytorch_code, ["result"]) def test_case_5(): @@ -95,3 +90,66 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[0.58987975, 0.80133516, 0.71605772, 0.46068805, 0.30434567, 0.41771618, 0.15606387, 0.88071585], + [0.67178625, 0.54522562, 0.83222342, 0.26114768, 0.77833325, 0.52892995, 0.26498035, 0.97040081]]]) + indices = torch.tensor([[[1, 3, 4, 7, 8, 10, 13, 14], + [1, 2, 5, 6, 8, 11, 13, 14]]]) + result = F.max_unpool1d(input=x, indices=indices, kernel_size=2, stride=2, padding=0, output_size=(15,15,15)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[0.58987975, 0.80133516, 0.71605772, 0.46068805, 0.30434567, 0.41771618, 0.15606387, 0.88071585], + [0.67178625, 0.54522562, 0.83222342, 0.26114768, 0.77833325, 0.52892995, 0.26498035, 0.97040081]]]) + indices = torch.tensor([[[1, 3, 4, 7, 8, 10, 13, 14], + [1, 2, 5, 6, 8, 11, 13, 14]]]) + result = F.max_unpool1d(x, indices, 2, 2, 0, (15,15,15)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + + x = torch.tensor([[[0.58987975, 0.80133516, 0.71605772, 0.46068805, 0.30434567, 0.41771618, 0.15606387, 0.88071585], + [0.67178625, 0.54522562, 0.83222342, 0.26114768, 0.77833325, 0.52892995, 0.26498035, 0.97040081]]]) + indices = torch.tensor([[[1, 3, 4, 7, 8, 10, 13, 14], + [1, 2, 5, 6, 8, 11, 13, 14]]]) + result = F.max_unpool1d(input=x, kernel_size=2, indices=indices, padding=0, stride=2, output_size=(15,15,15)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[0.58987975, 0.80133516, 0.71605772, 0.46068805, 0.30434567, 0.41771618, 0.15606387, 0.88071585], + [0.67178625, 0.54522562, 0.83222342, 0.26114768, 0.77833325, 0.52892995, 0.26498035, 0.97040081]]]) + indices = torch.tensor([[[1 , 3 , 4 , 7 , 8 , 10, 13, 14], + [1 , 2 , 5 , 6 , 8 , 11, 13, 14]]]) + result = F.max_unpool1d(x, indices, kernel_size=2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_softmax.py b/tests/test_nn_functional_softmax.py index c685ce0aa..ff47d5eee 100644 --- a/tests/test_nn_functional_softmax.py +++ b/tests/test_nn_functional_softmax.py @@ -67,12 +67,7 @@ def test_case_3(): result = F.softmax(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results due to the way to calculate dimensions", - ) + obj.run(pytorch_code, ["result"]) def test_case_4(): @@ -123,9 +118,55 @@ def test_case_6(): result = F.softmax(x, _stacklevel=2) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="When dim is None, paddle and pytorch generate different results due to the way to calculate dimensions", + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = 
textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.softmax(input=x, dim=1, _stacklevel=2, dtype=torch.float64) + """ ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.softmax(x, 1, 2, torch.float64) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + result = F.softmax(input=x, _stacklevel=2, dim=1, dtype=torch.float64) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_upsample.py b/tests/test_nn_functional_upsample.py index fdab971ac..db961b18b 100644 --- a/tests/test_nn_functional_upsample.py +++ b/tests/test_nn_functional_upsample.py @@ -95,3 +95,29 @@ def test_case_6(): """ ) obj.run(pytorch_code, unsupport=True, reason="align_corners is not supported") + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]]) + result = F.upsample(input=x, size=None, scale_factor=2.0, mode='nearest', align_corners=None) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="align_corners is not supported") + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]]) + result = F.upsample(x, None, 2.0, 'nearest', None) + """ + ) + obj.run(pytorch_code, unsupport=True, reason="align_corners is not supported") diff --git a/tests/test_nn_functional_upsample_bilinear.py b/tests/test_nn_functional_upsample_bilinear.py index fae11202a..a8101af8c 100644 --- a/tests/test_nn_functional_upsample_bilinear.py +++ b/tests/test_nn_functional_upsample_bilinear.py @@ -71,3 +71,57 @@ def test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_bilinear(input=input, size=None, scale_factor=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_bilinear(input, None, 2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def 
test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_bilinear(size=None, input=input, scale_factor=2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_upsample_nearest.py b/tests/test_nn_functional_upsample_nearest.py index c06a0051e..03aad9972 100644 --- a/tests/test_nn_functional_upsample_nearest.py +++ b/tests/test_nn_functional_upsample_nearest.py @@ -71,3 +71,57 @@ def test_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_nearest(input=input, size=None, scale_factor=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_nearest(input=input, scale_factor=2, size=None) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn.functional as F + import torch + input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]]) + result = F.upsample_nearest(input, None, 2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_norm.py b/tests/test_norm.py index 68699c4b0..6ef7dfa10 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -117,3 +117,47 @@ def test_case_7(): """ ) obj.run(pytorch_code, ["result", "out"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-12., -11., -10., -9. ], + [-8. , -7. , -6. , -5. ], + [-4. , -3. , -2. , -1. ]]) + out = torch.tensor([1.], dtype=torch.float64) + result = torch.norm(input=input, p=2, dim=1, keepdim=True, out=out, dtype=torch.float64) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-12., -11., -10., -9. ], + [-8. , -7. , -6. , -5. ], + [-4. , -3. , -2. , -1. ]]) + out = torch.tensor([1.], dtype=torch.float64) + result = torch.norm(input=input, keepdim=True, dim=1, p=2, out=out, dtype=torch.float64) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-12., -11., -10., -9. ], + [-8. , -7. , -6. , -5. ], + [-4. , -3. , -2. , -1. ]], + [[ 0. , 1. , 2. , 3. ], + [ 4. , 5. , 6. , 7. ], + [ 8. , 9. 
, 10., 11.]]]) + result = torch.norm(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_optim_Adagrad.py b/tests/test_optim_Adagrad.py index 8682fd819..bd563c191 100644 --- a/tests/test_optim_Adagrad.py +++ b/tests/test_optim_Adagrad.py @@ -92,3 +92,29 @@ def test_case_8(): unsupport=True, reason="`lr_decay`, `foreach`, 'maximize` and `differentiable` is not supported.", ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.Adagrad(conv.parameters(), 0.01, 0, 0, 0, 1e-10, None, maximize=False, differentiable=False)" + ) + ) + obj.run( + pytorch_code, + unsupport=True, + reason="`lr_decay`, `foreach`, 'maximize` and `differentiable` is not supported.", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.Adagrad(params=conv.parameters(), lr_decay=0, lr=0.01, initial_accumulator_value=0, weight_decay=0, eps=1e-10, foreach=None, maximize=False, differentiable=False)" + ) + ) + obj.run( + pytorch_code, + unsupport=True, + reason="`lr_decay`, `foreach`, 'maximize` and `differentiable` is not supported.", + ) diff --git a/tests/test_optim_SGD.py b/tests/test_optim_SGD.py index 0cad01c04..262f17b38 100644 --- a/tests/test_optim_SGD.py +++ b/tests/test_optim_SGD.py @@ -90,3 +90,29 @@ def test_case_8(): unsupport=True, reason="`momentum`, `dampening`, `nesterov`, `maximize`, `foreach` and `differentiable` is not supported.", ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.SGD(conv.parameters(), 0.8, 0, 0, 0, False, maximize=False, foreach=None, differentiable=False)" + ) + ) + obj.run( + pytorch_code, + unsupport=True, + reason="`momentum`, `dampening`, `nesterov`, `maximize`, `foreach` and `differentiable` is not supported.", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.SGD(params=conv.parameters(), lr=0.8, weight_decay=0, momentum=0, dampening=0, maximize=False, nesterov=False, foreach=None, differentiable=False)" + ) + ) + obj.run( + pytorch_code, + unsupport=True, + reason="`momentum`, `dampening`, `nesterov`, `maximize`, `foreach` and `differentiable` is not supported.", + )