From 469a523b13ad235794356382779cc447f0197d7c Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Wed, 18 Sep 2024 14:09:15 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E3=80=90Hackathon=207th=20No.39=E3=80=91?= =?UTF-8?q?=E4=B8=BA=20Paddle=20=E4=BB=A3=E7=A0=81=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E6=96=B0=E5=A2=9E=20API=20=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E8=A7=84=E5=88=99=EF=BC=88=E7=AC=AC=206=20=E7=BB=84=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paconvert/api_mapping.json | 198 ++++++++++++++++++ paconvert/api_matcher.py | 30 +++ tests/test_nn_AdaptiveLogSoftmaxWithLoss.py | 83 ++++++++ tests/test_nn_CircularPad3d.py | 83 ++++++++ tests/test_nn_LPPool1d.py | 132 ++++++++++++ tests/test_nn_LPPool2d.py | 163 ++++++++++++++ tests/test_nn_Softmin.py | 91 ++++++++ ...est_nn_functional_feature_alpha_dropout.py | 98 +++++++++ tests/test_nn_functional_lp_pool1d.py | 108 ++++++++++ tests/test_nn_functional_lp_pool2d.py | 138 ++++++++++++ tests/test_nn_functional_threshold_.py | 84 ++++++++ ...t_nn_utils_parametrizations_weight_norm.py | 101 +++++++++ tests/test_optim_NAdam.py | 88 ++++++++ tests/test_optim_RAdam.py | 79 +++++++ 14 files changed, 1476 insertions(+) create mode 100644 tests/test_nn_AdaptiveLogSoftmaxWithLoss.py create mode 100644 tests/test_nn_CircularPad3d.py create mode 100644 tests/test_nn_LPPool1d.py create mode 100644 tests/test_nn_LPPool2d.py create mode 100644 tests/test_nn_Softmin.py create mode 100644 tests/test_nn_functional_feature_alpha_dropout.py create mode 100644 tests/test_nn_functional_lp_pool1d.py create mode 100644 tests/test_nn_functional_lp_pool2d.py create mode 100644 tests/test_nn_functional_threshold_.py create mode 100644 tests/test_nn_utils_parametrizations_weight_norm.py create mode 100644 tests/test_optim_NAdam.py create mode 100644 tests/test_optim_RAdam.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 10232f952..747982b12 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -9277,6 +9277,20 @@ "output_size" ] }, + "torch.nn.AdaptiveLogSoftmaxWithLoss": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.AdaptiveLogSoftmaxWithLoss", + "min_input_args": 3, + "args_list": [ + "in_features", + "n_classes", + "cutoffs", + "div_value", + "head_bias", + "device", + "dtype" + ] + }, "torch.nn.AdaptiveMaxPool1d": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.AdaptiveMaxPool1D", @@ -9490,6 +9504,17 @@ "groups" ] }, + "torch.nn.CircularPad3d": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.Pad3D", + "min_input_args": 1, + "args_list": [ + "padding" + ], + "paddle_default_kwargs": { + "mode": "'circular'" + } + }, "torch.nn.ConstantPad1d": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.Pad1D", @@ -10077,6 +10102,28 @@ ], "min_input_args": 0 }, + "torch.nn.LPPool1d": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.LPPool1D", + "min_input_args": 2, + "args_list": [ + "norm_type", + "kernel_size", + "stride", + "ceil_mode" + ] + }, + "torch.nn.LPPool2d": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.LPPool2D", + "min_input_args": 2, + "args_list": [ + "norm_type", + "kernel_size", + "stride", + "ceil_mode" + ] + }, "torch.nn.LSTM": { "Matcher": "RNNMatcher", "paddle_api": "paddle.nn.LSTM", @@ -11045,6 +11092,20 @@ }, "min_input_args": 0 }, + "torch.nn.Softmin": { + "Matcher": "SoftminMatcher", + "paddle_api": "paddle.nn.Softmax", + "args_list": [ + "dim" + ], + "kwargs_change": { + "dim": "axis" + }, + "paddle_default_kwargs": { + "axis": 0 + }, + "min_input_args": 0 + }, "torch.nn.Softplus": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.Softplus", @@ -11827,6 +11888,20 @@ ], "min_input_args": 2 }, + "torch.nn.functional.feature_alpha_dropout": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.functional.feature_alpha_dropout", + "min_input_args": 1, + "args_list": [ + "input", + "p", + "training", + "inplace" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.nn.functional.fold": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.functional.fold", @@ -12206,6 +12281,36 @@ "input": "x" } }, + "torch.nn.functional.lp_pool1d": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.functional.lp_pool1d", + "min_input_args": 2, + "args_list": [ + "input", + "norm_type", + "kernel_size", + "stride", + "ceil_mode" + ], + "kwargs_change": { + "input": "x" + } + }, + "torch.nn.functional.lp_pool2d": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.functional.lp_pool2d", + "min_input_args": 2, + "args_list": [ + "input", + "norm_type", + "kernel_size", + "stride", + "ceil_mode" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.nn.functional.margin_ranking_loss": { "Matcher": "SizeAverageMatcher", "paddle_api": "paddle.nn.functional.margin_ranking_loss", @@ -12795,6 +12900,19 @@ "input": "x" } }, + "torch.nn.functional.threshold_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.functional.thresholded_relu_", + "min_input_args": 3, + "args_list": [ + "input", + "threshold", + "value" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.nn.functional.triplet_margin_loss": { "Matcher": "SizeAverageMatcher", "paddle_api": "paddle.nn.functional.triplet_margin_loss", @@ -13186,6 +13304,19 @@ }, "min_input_args": 1 }, + "torch.nn.utils.parametrizations.weight_norm": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nn.utils.weight_norm", + "args_list": [ + "module", + "name", + "dim" + ], + "kwargs_change": { + "module": "layer" + }, + "min_input_args": 1 + }, "torch.nn.utils.remove_weight_norm": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.utils.remove_weight_norm", @@ -13535,6 +13666,40 @@ "lr": "learning_rate" } }, + "torch.optim.NAdam": { + "Matcher": "OptimAdamMatcher", + "paddle_api": "paddle.optimizer.NAdam", + "min_input_args": 1, + "args_list": [ + "params", + "lr", + "betas", + "eps", + "weight_decay", + "momentum_decay", + "decoupled_weight_decay", + "*", + "foreach", + "maximize", + "capturable", + "differentiable" + ], + "unsupport_args": [ + "decoupled_weight_decay", + "foreach", + "maximize", + "capturable", + "differentiable" + ], + "kwargs_change": { + "params": "parameters", + "lr": "learning_rate", + "eps": "epsilon" + }, + "paddle_default_kwargs": { + "weight_decay": 0.0 + } + }, "torch.optim.Optimizer": { "Matcher": "OptimOptimizerMatcher", "paddle_api": "paddle.optimizer.Optimizer", @@ -13575,6 +13740,39 @@ "torch.optim.Optimizer.zero_grad": { "min_input_args": 0 }, + "torch.optim.RAdam": { + "Matcher": "OptimAdamMatcher", + "paddle_api": "paddle.optimizer.RAdam", + "min_input_args": 1, + "args_list": [ + "params", + "lr", + "betas", + "eps", + "weight_decay", + "decoupled_weight_decay", + "*", + "foreach", + "maximize", + "capturable", + "differentiable" + ], + "unsupport_args": [ + "decoupled_weight_decay", + "foreach", + "maximize", + "capturable", + "differentiable" + ], + "kwargs_change": { + "params": "parameters", + "lr": "learning_rate", + "eps": "epsilon" + }, + "paddle_default_kwargs": { + "weight_decay": 0.0 + } + }, "torch.optim.RMSprop": { "Matcher": "GenericMatcher", "paddle_api": "paddle.optimizer.RMSProp", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4b4600203..6802426ca 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3921,6 +3921,36 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class SoftminMatcher(SoftmaxMatcher): + def generate_code(self, kwargs): + self.paddle_api = "paddle.nn.Softmin" + return super().generate_code(kwargs) + + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + def _get_softmax_dim(axis: int) -> int: + if axis == 0 or axis == 1 or axis == 3: + ret = 0 + else: + ret = 1 + return ret + + def forward(self,x): + if self._axis is None: + return paddle.nn.functional.softmax(x, _get_softmax_dim(x.ndim)) + return paddle.nn.functional.softmax(x, self._axis) + setattr(paddle.nn.Softmax, 'forward', forward) + + class Softmin(paddle.nn.Softmax): + def forward(self, x): + return super().forward(-x) + setattr(paddle.nn, 'Softmin', Softmin) + """ + ) + return CODE_TEMPLATE + + class OptimOptimizerMatcher(BaseMatcher): def generate_code(self, kwargs): code = "paddle.optimizer.Optimizer(parameters={}, **{})".format( diff --git a/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py b/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py new file mode 100644 index 000000000..e5d82f25e --- /dev/null +++ b/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.AdaptiveLogSoftmaxWithLoss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ], + [-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588], + [-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]]) + target = torch.tensor([1, 1, 1]) + asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [2]) + out, loss = asfm(input,target) + """ + ) + obj.run(pytorch_code, ["out", "loss"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ], + [-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588], + [-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]]) + target = torch.tensor([1, 1, 1]) + asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [3], div_value=2.0) + out, loss = asfm(input,target) + """ + ) + obj.run(pytorch_code, ["out", "loss"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ], + [-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588], + [-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]]) + target = torch.tensor([1, 1, 1]) + asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [1], div_value=3.8, head_bias=True) + out, loss = asfm(input,target) + """ + ) + obj.run(pytorch_code, ["out", "loss"], check_value=False) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ], + [-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588], + [-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]]) + target = torch.tensor([1, 1, 1]) + asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=8, cutoffs=[5], div_value=3.8, head_bias=True) + out, loss = asfm(input,target) + """ + ) + obj.run(pytorch_code, ["out", "loss"], check_value=False) diff --git a/tests/test_nn_CircularPad3d.py b/tests/test_nn_CircularPad3d.py new file mode 100644 index 000000000..2130b8dc8 --- /dev/null +++ b/tests/test_nn_CircularPad3d.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.CircularPad3d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-1.3328, -0.4948], + [ 0.8689, 1.1423]], + [[-0.2671, -1.0868], + [ 1.3011, 1.0469]]]]]) + model = nn.CircularPad3d(1) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-1.3328, -0.4948], + [ 0.8689, 1.1423]], + [[-0.2671, -1.0868], + [ 1.3011, 1.0469]]]]]) + model = nn.CircularPad3d((1, 1, 1, 1, 1, 1)) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-1.3328, -0.4948], + [ 0.8689, 1.1423]], + [[-0.2671, -1.0868], + [ 1.3011, 1.0469]]]]]) + model = nn.CircularPad3d(padding=1) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[[[-1.3328, -0.4948], + [ 0.8689, 1.1423]], + [[-0.2671, -1.0868], + [ 1.3011, 1.0469]]]]]) + model = torch.nn.CircularPad3d(padding=(1, 2, 1, 2, 2, 1)) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_LPPool1d.py b/tests/test_nn_LPPool1d.py new file mode 100644 index 000000000..fb89d91a7 --- /dev/null +++ b/tests/test_nn_LPPool1d.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.LPPool1d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + pool = torch.nn.LPPool1d(1, 3, stride=2) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]) + pool = torch.nn.LPPool1d(2, 4, stride=2) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + pool = torch.nn.LPPool1d(float('inf'), 3, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + pool = torch.nn.LPPool1d(10, 3, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]) + pool = nn.LPPool1d(norm_type=2, kernel_size=2, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_LPPool2d.py b/tests/test_nn_LPPool2d.py new file mode 100644 index 000000000..e0e0f7d2c --- /dev/null +++ b/tests/test_nn_LPPool2d.py @@ -0,0 +1,163 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.LPPool2d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + pool = torch.nn.LPPool2d(1, 3, stride=2) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + pool = torch.nn.LPPool2d(2, 4, stride=2) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + pool = torch.nn.LPPool2d(10, 3, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + pool = torch.nn.LPPool2d(10, 3, stride=4) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + pool = nn.LPPool2d(norm_type=6, kernel_size=2, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + pool = nn.LPPool2d(norm_type=float('inf'), kernel_size=2, stride=2, ceil_mode=True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) diff --git a/tests/test_nn_Softmin.py b/tests/test_nn_Softmin.py new file mode 100644 index 000000000..8b5c43f0b --- /dev/null +++ b/tests/test_nn_Softmin.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.Softmin", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmin(-1) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmin(dim=1) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmin() + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 10.0]], + [[1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0]]]) + model = nn.Softmin(dim=None) + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_feature_alpha_dropout.py b/tests/test_nn_functional_feature_alpha_dropout.py new file mode 100644 index 000000000..9a5a60414 --- /dev/null +++ b/tests/test_nn_functional_feature_alpha_dropout.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.feature_alpha_dropout") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(x, 0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(x, p=0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(x, 0.5, True, True) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(input=x, p=0.5, training=True, inplace=True) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +# generated by validate_unittest autofix, based on test_case_5 +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = F.feature_alpha_dropout(inplace=True, training=True, p=0.5, input=x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_nn_functional_lp_pool1d.py b/tests/test_nn_functional_lp_pool1d.py new file mode 100644 index 000000000..73da41ba6 --- /dev/null +++ b/tests/test_nn_functional_lp_pool1d.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.lp_pool1d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + result = torch.nn.functional.lp_pool1d(input, 1, 2) + """ + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [0.2363, 1.9341, 0.8522, -0.1846], + [1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[0.1205, -0.1889, 0.5086, -0.8080], + [0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]) + result = torch.nn.functional.lp_pool1d(input, 4, 2, 2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + result = torch.nn.functional.lp_pool1d(input, -float('inf'), 3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + result = torch.nn.functional.lp_pool1d(input=input, norm_type=4, kernel_size=3, stride=2, ceil_mode=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + result = torch.nn.functional.lp_pool1d(input, float('inf'), 2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_lp_pool2d.py b/tests/test_nn_functional_lp_pool2d.py new file mode 100644 index 000000000..0badbaa06 --- /dev/null +++ b/tests/test_nn_functional_lp_pool2d.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.lp_pool2d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + result = torch.nn.functional.lp_pool2d(input, 1, 3, stride=2) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + result = torch.nn.functional.lp_pool2d(input, 2, 4, stride=2) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + result = torch.nn.functional.lp_pool2d(input, 20, 3, stride=2, ceil_mode=True) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + result = torch.nn.functional.lp_pool2d(input, 4, 3, stride=2, ceil_mode=True) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + result = F.lp_pool2d(input=input, norm_type=2, kernel_size=2, stride=2, ceil_mode=True) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [ 0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [ 0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [ 0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]]) + result = torch.nn.functional.lp_pool2d(input, float('inf'), 3, stride=2) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) diff --git a/tests/test_nn_functional_threshold_.py b/tests/test_nn_functional_threshold_.py new file mode 100644 index 000000000..ff248f5d2 --- /dev/null +++ b/tests/test_nn_functional_threshold_.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.threshold_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(x, 0.5, 0.0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(x, threshold=0.5, value=0.0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(x, value=0.0, threshold=0.5) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(x, 0.5, 0.0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(x, threshold=0.5, value=0.1) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_utils_parametrizations_weight_norm.py b/tests/test_nn_utils_parametrizations_weight_norm.py new file mode 100644 index 000000000..5943f7b79 --- /dev/null +++ b/tests/test_nn_utils_parametrizations_weight_norm.py @@ -0,0 +1,101 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.utils.parametrizations.weight_norm") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40), name='weight') + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40), name='weight', dim=0) + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +# generated by validate_unittest autofix, based on test_case_2 +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40), 'weight', 0) + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +# generated by validate_unittest autofix, based on test_case_2 +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(module=nn.Linear(20, 40), name='weight', dim=0) + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +# generated by validate_unittest autofix, based on test_case_2 +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(dim=0, name='weight', module=nn.Linear(20, 40)) + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +# generated by validate_unittest autofix, based on test_case_2 +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + m = torch.nn.utils.parametrizations.weight_norm(nn.Linear(20, 40)) + a = torch.ones(20) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_optim_NAdam.py b/tests/test_optim_NAdam.py new file mode 100644 index 000000000..9ca3a2db1 --- /dev/null +++ b/tests/test_optim_NAdam.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase +from optimizer_helper import generate_optimizer_test_code + +obj = APIBase("torch.optim.NAdam") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code("torch.optim.NAdam(conv.parameters(), eps=1e-7)") + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), betas=(0.5, 0.99))" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), weight_decay=0.01)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code("torch.optim.NAdam(conv.parameters())") + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0., momentum_decay=0.005)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) diff --git a/tests/test_optim_RAdam.py b/tests/test_optim_RAdam.py new file mode 100644 index 000000000..36d2008b4 --- /dev/null +++ b/tests/test_optim_RAdam.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase +from optimizer_helper import generate_optimizer_test_code + +obj = APIBase("torch.optim.RAdam") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code("torch.optim.RAdam(conv.parameters(), eps=1e-7)") + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), betas=(0.5, 0.99))" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), weight_decay=0.01)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code("torch.optim.RAdam(conv.parameters())") + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" + ) + ) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) From 68fa44f27bbd80b1a2080303a793995ffc13a188 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Tue, 8 Oct 2024 19:35:53 +0800 Subject: [PATCH 02/10] add skip --- tests/test_nn_LPPool1d.py | 5 +++++ tests/test_nn_LPPool2d.py | 5 +++++ tests/test_nn_functional_lp_pool1d.py | 9 +++++++++ tests/test_nn_functional_lp_pool2d.py | 5 +++++ 4 files changed, 24 insertions(+) diff --git a/tests/test_nn_LPPool1d.py b/tests/test_nn_LPPool1d.py index fb89d91a7..96af8aace 100644 --- a/tests/test_nn_LPPool1d.py +++ b/tests/test_nn_LPPool1d.py @@ -14,6 +14,7 @@ import textwrap +import pytest from apibase import APIBase obj = APIBase("torch.nn.LPPool1d") @@ -67,6 +68,10 @@ def test_case_2(): obj.run(pytorch_code, ["result"]) +@pytest.mark.skipif( + condition=True, + reason="`lp_pool` in PyTorch has a wrong implementation which will return a tensor full of 1.", +) def test_case_3(): pytorch_code = textwrap.dedent( """ diff --git a/tests/test_nn_LPPool2d.py b/tests/test_nn_LPPool2d.py index e0e0f7d2c..9e1e04bf9 100644 --- a/tests/test_nn_LPPool2d.py +++ b/tests/test_nn_LPPool2d.py @@ -14,6 +14,7 @@ import textwrap +import pytest from apibase import APIBase obj = APIBase("torch.nn.LPPool2d") @@ -132,6 +133,10 @@ def test_case_5(): obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) +@pytest.mark.skipif( + condition=True, + reason="`lp_pool` in PyTorch has a wrong implementation which will return a tensor full of 1.", +) def test_case_6(): pytorch_code = textwrap.dedent( """ diff --git a/tests/test_nn_functional_lp_pool1d.py b/tests/test_nn_functional_lp_pool1d.py index 73da41ba6..6e8d7223d 100644 --- a/tests/test_nn_functional_lp_pool1d.py +++ b/tests/test_nn_functional_lp_pool1d.py @@ -14,6 +14,7 @@ import textwrap +import pytest from apibase import APIBase obj = APIBase("torch.nn.functional.lp_pool1d") @@ -63,6 +64,10 @@ def test_case_2(): obj.run(pytorch_code, ["result"]) +@pytest.mark.skipif( + condition=True, + reason="`lp_pool` in PyTorch has a wrong implementation which will return a tensor full of 1.", +) def test_case_3(): pytorch_code = textwrap.dedent( """ @@ -93,6 +98,10 @@ def test_case_4(): obj.run(pytorch_code, ["result"]) +@pytest.mark.skipif( + condition=True, + reason="`lp_pool` in PyTorch has a wrong implementation which will return a tensor full of 1.", +) def test_case_5(): pytorch_code = textwrap.dedent( """ diff --git a/tests/test_nn_functional_lp_pool2d.py b/tests/test_nn_functional_lp_pool2d.py index 0badbaa06..50b04b22e 100644 --- a/tests/test_nn_functional_lp_pool2d.py +++ b/tests/test_nn_functional_lp_pool2d.py @@ -14,6 +14,7 @@ import textwrap +import pytest from apibase import APIBase obj = APIBase("torch.nn.functional.lp_pool2d") @@ -123,6 +124,10 @@ def test_case_5(): obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) +@pytest.mark.skipif( + condition=True, + reason="`lp_pool` in PyTorch has a wrong implementation which will return a tensor full of 1.", +) def test_case_6(): pytorch_code = textwrap.dedent( """ From e590a24995e3cce53702a92c97f0d11774dc5bf0 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Fri, 25 Oct 2024 15:33:37 +0800 Subject: [PATCH 03/10] update Softmin --- paconvert/api_matcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 3e36164a0..654510aa5 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3916,7 +3916,7 @@ def generate_code(self, kwargs): class SoftminMatcher(SoftmaxMatcher): def generate_code(self, kwargs): - self.paddle_api = "paddle.nn.Softmin" + self.paddle_api = "paddle_aux.Softmin" return super().generate_code(kwargs) def generate_aux_code(self): @@ -3938,7 +3938,6 @@ def forward(self,x): class Softmin(paddle.nn.Softmax): def forward(self, x): return super().forward(-x) - setattr(paddle.nn, 'Softmin', Softmin) """ ) return CODE_TEMPLATE From ff99c665644090e737b6a2c6b6131f1feb2a82d5 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Thu, 7 Nov 2024 10:47:00 +0800 Subject: [PATCH 04/10] update --- tests/test_optim_NAdam.py | 16 ++++++++-------- tests/test_optim_RAdam.py | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_optim_NAdam.py b/tests/test_optim_NAdam.py index 9ca3a2db1..3e128db2d 100644 --- a/tests/test_optim_NAdam.py +++ b/tests/test_optim_NAdam.py @@ -24,7 +24,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.NAdam(conv.parameters(), eps=1e-7)") ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_2(): @@ -33,7 +33,7 @@ def test_case_2(): "torch.optim.NAdam(conv.parameters(), betas=(0.5, 0.99))" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_3(): @@ -42,7 +42,7 @@ def test_case_3(): "torch.optim.NAdam(conv.parameters(), weight_decay=0.01)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_4(): @@ -51,14 +51,14 @@ def test_case_4(): "torch.optim.NAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_5(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.NAdam(conv.parameters())") ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_6(): @@ -67,7 +67,7 @@ def test_case_6(): "torch.optim.NAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_7(): @@ -76,7 +76,7 @@ def test_case_7(): "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_8(): @@ -85,4 +85,4 @@ def test_case_8(): "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0., momentum_decay=0.005)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) diff --git a/tests/test_optim_RAdam.py b/tests/test_optim_RAdam.py index 36d2008b4..34fc822e0 100644 --- a/tests/test_optim_RAdam.py +++ b/tests/test_optim_RAdam.py @@ -24,7 +24,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.RAdam(conv.parameters(), eps=1e-7)") ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_2(): @@ -33,7 +33,7 @@ def test_case_2(): "torch.optim.RAdam(conv.parameters(), betas=(0.5, 0.99))" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_3(): @@ -42,7 +42,7 @@ def test_case_3(): "torch.optim.RAdam(conv.parameters(), weight_decay=0.01)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_4(): @@ -51,14 +51,14 @@ def test_case_4(): "torch.optim.RAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_5(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.RAdam(conv.parameters())") ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_6(): @@ -67,7 +67,7 @@ def test_case_6(): "torch.optim.RAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) def test_case_7(): @@ -76,4 +76,4 @@ def test_case_7(): "torch.optim.RAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], rtol=1.0e-5) + obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) From 340d69d53e7c737c14f60398c9ae1adacd0cc76e Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Thu, 7 Nov 2024 11:04:53 +0800 Subject: [PATCH 05/10] update --- tests/test_optim_NAdam.py | 86 +++++++++++++++++++++++++++++++++++---- tests/test_optim_RAdam.py | 84 ++++++++++++++++++++++++++++++++++---- 2 files changed, 155 insertions(+), 15 deletions(-) diff --git a/tests/test_optim_NAdam.py b/tests/test_optim_NAdam.py index 3e128db2d..a98bc59d7 100644 --- a/tests/test_optim_NAdam.py +++ b/tests/test_optim_NAdam.py @@ -24,7 +24,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.NAdam(conv.parameters(), eps=1e-7)") ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_2(): @@ -33,7 +33,7 @@ def test_case_2(): "torch.optim.NAdam(conv.parameters(), betas=(0.5, 0.99))" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_3(): @@ -42,7 +42,7 @@ def test_case_3(): "torch.optim.NAdam(conv.parameters(), weight_decay=0.01)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_4(): @@ -51,14 +51,14 @@ def test_case_4(): "torch.optim.NAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_5(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.NAdam(conv.parameters())") ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_6(): @@ -67,7 +67,7 @@ def test_case_6(): "torch.optim.NAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_7(): @@ -76,7 +76,7 @@ def test_case_7(): "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_8(): @@ -85,4 +85,74 @@ def test_case_8(): "torch.optim.NAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0., momentum_decay=0.005)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), eps=1e-7, decoupled_weight_decay=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `decoupled_weight_decay`", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), eps=1e-7, foreach=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `foreach`", + ) + + +def test_case_11(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), eps=1e-7, maximize=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `maximize`", + ) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), eps=1e-7, capturable=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `capturable`", + ) + + +def test_case_13(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.NAdam(conv.parameters(), eps=1e-7, differentiable=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `differentiable`", + ) diff --git a/tests/test_optim_RAdam.py b/tests/test_optim_RAdam.py index 34fc822e0..fbb5ae05b 100644 --- a/tests/test_optim_RAdam.py +++ b/tests/test_optim_RAdam.py @@ -24,7 +24,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.RAdam(conv.parameters(), eps=1e-7)") ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_2(): @@ -33,7 +33,7 @@ def test_case_2(): "torch.optim.RAdam(conv.parameters(), betas=(0.5, 0.99))" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_3(): @@ -42,7 +42,7 @@ def test_case_3(): "torch.optim.RAdam(conv.parameters(), weight_decay=0.01)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_4(): @@ -51,14 +51,14 @@ def test_case_4(): "torch.optim.RAdam(params=conv.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_5(): pytorch_code = textwrap.dedent( generate_optimizer_test_code("torch.optim.RAdam(conv.parameters())") ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_6(): @@ -67,7 +67,7 @@ def test_case_6(): "torch.optim.RAdam(conv.parameters(), 0.001, (0.9, 0.999), 1e-08, 0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) def test_case_7(): @@ -76,4 +76,74 @@ def test_case_7(): "torch.optim.RAdam(betas=(0.9, 0.999), lr=0.001, params=conv.parameters(), eps=1e-08, weight_decay=0.)" ) ) - obj.run(pytorch_code, ["result"], unsupport=True, rtol=1.0e-5) + obj.run(pytorch_code, ["result"], rtol=1.0e-5) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), eps=1e-7, differentiable=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `differentiable`", + ) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), eps=1e-7, decoupled_weight_decay=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `decoupled_weight_decay`", + ) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), eps=1e-7, foreach=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `foreach`", + ) + + +def test_case_11(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), eps=1e-7, maximize=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `maximize`", + ) + + +def test_case_12(): + pytorch_code = textwrap.dedent( + generate_optimizer_test_code( + "torch.optim.RAdam(conv.parameters(), eps=1e-7, capturable=True)" + ) + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle do not support `capturable`", + ) From a6e5ed5d1ca0bcc6cfcddde869db1409768fbe19 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Mon, 25 Nov 2024 19:01:50 +0800 Subject: [PATCH 06/10] fix --- paconvert/api_matcher.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 18723da79..f274e781f 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4306,6 +4306,7 @@ def generate_code(self, kwargs): class SoftminMatcher(SoftmaxMatcher): def generate_code(self, kwargs): self.paddle_api = "paddle_aux.Softmin" + self.write_aux_code() return super().generate_code(kwargs) def generate_aux_code(self): @@ -4876,7 +4877,7 @@ def generate_code(self, kwargs): class PositiveMatcher(BaseMatcher): def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( - """ + """ def positive(x): if x.dtype != paddle.bool: return x @@ -5233,7 +5234,7 @@ def get_scalable_var(self): if not (arg_name.startswith("*") and len(arg_name) > 1): return None return arg_name[1:] - + def get_paddle_nodes(self, args, kwargs): var_arg_name = self.get_scalable_var() dest_var_arg_name = self.api_mapping.get("kwargs_change", {}).get( @@ -5251,7 +5252,7 @@ def get_paddle_nodes(self, args, kwargs): return ast.parse(code).body -class ScalableVarMatcher(BaseMatcher): +class ScalableVarMatcher(BaseMatcher): def get_scalable_var(self): args_list = self.api_mapping.get("args_list", []) if len(args_list) != 1: From 555dbd312473376d3de7b7a61d0ac7767f53f9b7 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Mon, 25 Nov 2024 19:09:35 +0800 Subject: [PATCH 07/10] fix codestyle check --- paconvert/api_alias_mapping.json | 2 +- paconvert/api_mapping.json | 24 +++++++++---------- tests/test_Tensor_scatter_reduce.py | 1 + tests/test_block_diag.py | 4 ++-- ...est_nn_functional_feature_alpha_dropout.py | 2 +- tests/test_scatter_reduce.py | 1 + 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/paconvert/api_alias_mapping.json b/paconvert/api_alias_mapping.json index 4ab131099..7c60838ea 100644 --- a/paconvert/api_alias_mapping.json +++ b/paconvert/api_alias_mapping.json @@ -48,8 +48,8 @@ "torch.bilinear": "torch.nn.functional.bilinear", "torch.celu_": "torch.nn.functional.celu_", "torch.channel_shuffle": "torch.nn.functional.channel_shuffle", - "torch.concatenate": "torch.cat", "torch.clip": "torch.clamp", + "torch.concatenate": "torch.cat", "torch.conv1d": "torch.nn.functional.conv1d", "torch.conv2d": "torch.nn.functional.conv2d", "torch.conv3d": "torch.nn.functional.conv3d", diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 89f4eedeb..578685eba 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3192,7 +3192,7 @@ ] }, "torch.Tensor.positive": { - "Matcher": "PositiveMatcher" + "Matcher": "PositiveMatcher" }, "torch.Tensor.pow": { "Matcher": "GenericMatcher", @@ -3571,7 +3571,7 @@ } }, "torch.Tensor.scatter_reduce": { - "Matcher": "ScatterReduceMatcher", + "Matcher": "ScatterReduceMatcher", "paddle_api": "paddle.Tensor.put_along_axis", "min_input_args": 3, "args_list": [ @@ -7836,6 +7836,16 @@ "axis": 0 } }, + "torch.float_power": { + "Matcher": "FloatPowerMatcher", + "min_input_args": 2, + "args_list": [ + "input", + "exponent", + "*", + "out" + ] + }, "torch.floor": { "Matcher": "GenericMatcher", "paddle_api": "paddle.floor", @@ -7849,16 +7859,6 @@ "input": "x" } }, - "torch.float_power": { - "Matcher": "FloatPowerMatcher", - "min_input_args": 2, - "args_list": [ - "input", - "exponent", - "*", - "out" - ] - }, "torch.floor_divide": { "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.floor_divide", diff --git a/tests/test_Tensor_scatter_reduce.py b/tests/test_Tensor_scatter_reduce.py index 6d6e27fb5..2b22eb15d 100644 --- a/tests/test_Tensor_scatter_reduce.py +++ b/tests/test_Tensor_scatter_reduce.py @@ -111,6 +111,7 @@ def test_case_7(): ) obj.run(pytorch_code, ["result"]) + def test_case_8(): pytorch_code = textwrap.dedent( """ diff --git a/tests/test_block_diag.py b/tests/test_block_diag.py index bfc343d84..7589caff5 100644 --- a/tests/test_block_diag.py +++ b/tests/test_block_diag.py @@ -41,8 +41,8 @@ def test_case_2(): A = torch.tensor([[4], [3], [2]]) B = torch.tensor([7, 6, 5]) C = torch.tensor(1) - result = torch.block_diag(torch.tensor([[4], [3], [2]]), - torch.tensor([7, 6, 5]), + result = torch.block_diag(torch.tensor([[4], [3], [2]]), + torch.tensor([7, 6, 5]), torch.tensor(1)) """ ) diff --git a/tests/test_nn_functional_feature_alpha_dropout.py b/tests/test_nn_functional_feature_alpha_dropout.py index 0338c51d5..0ac7d82fd 100644 --- a/tests/test_nn_functional_feature_alpha_dropout.py +++ b/tests/test_nn_functional_feature_alpha_dropout.py @@ -74,4 +74,4 @@ def test_case_5(): obj.run( pytorch_code, ["result"], - ) \ No newline at end of file + ) diff --git a/tests/test_scatter_reduce.py b/tests/test_scatter_reduce.py index 9ffc30f0d..e7a430229 100644 --- a/tests/test_scatter_reduce.py +++ b/tests/test_scatter_reduce.py @@ -111,6 +111,7 @@ def test_case_7(): ) obj.run(pytorch_code, ["result"]) + def test_case_8(): pytorch_code = textwrap.dedent( """ From 0f8d128f0d10bb8e54b2b66554e52f8afc5d6724 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Tue, 26 Nov 2024 01:12:06 +0800 Subject: [PATCH 08/10] fix --- paconvert/api_mapping.json | 4 +- tests/test_nn_AdaptiveLogSoftmaxWithLoss.py | 16 ++++++ tests/test_nn_LPPool1d.py | 62 +++++++++++++++++++++ tests/test_nn_LPPool2d.py | 62 +++++++++++++++++++++ tests/test_nn_Softmax.py | 15 +++++ tests/test_nn_Softmin.py | 15 +++++ tests/test_nn_functional_lp_pool1d.py | 15 +++++ tests/test_nn_functional_lp_pool2d.py | 30 ++++++++++ tests/test_nn_functional_threshold_.py | 26 +++++++++ 9 files changed, 243 insertions(+), 2 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 578685eba..b34683cbe 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -11955,7 +11955,7 @@ "dim": "axis" }, "paddle_default_kwargs": { - "axis": 0 + "axis": null } }, "torch.nn.Softmax2d": { @@ -11976,7 +11976,7 @@ "dim": "axis" }, "paddle_default_kwargs": { - "axis": 0 + "axis": null }, "min_input_args": 0 }, diff --git a/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py b/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py index e5d82f25e..a628d9d72 100644 --- a/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py +++ b/tests/test_nn_AdaptiveLogSoftmaxWithLoss.py @@ -81,3 +81,19 @@ def test_case_4(): """ ) obj.run(pytorch_code, ["out", "loss"], check_value=False) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ], + [-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588], + [-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]]) + target = torch.tensor([1, 1, 1]) + asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(n_classes=8, in_features=5, div_value=3.8, cutoffs=[5], head_bias=True) + out, loss = asfm(input,target) + """ + ) + obj.run(pytorch_code, ["out", "loss"], check_value=False) diff --git a/tests/test_nn_LPPool1d.py b/tests/test_nn_LPPool1d.py index 96af8aace..435fc258a 100644 --- a/tests/test_nn_LPPool1d.py +++ b/tests/test_nn_LPPool1d.py @@ -135,3 +135,65 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]) + pool = nn.LPPool1d(2, 2, 2, True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]) + pool = nn.LPPool1d(kernel_size=2, stride=2, ceil_mode=True, norm_type=2) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_LPPool2d.py b/tests/test_nn_LPPool2d.py index 9e1e04bf9..1c30c8c10 100644 --- a/tests/test_nn_LPPool2d.py +++ b/tests/test_nn_LPPool2d.py @@ -166,3 +166,65 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + pool = nn.LPPool2d(6, 2, 2, True) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + pool = nn.LPPool2d(kernel_size=2, stride=2, ceil_mode=True, norm_type=6) + result = pool(input) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) diff --git a/tests/test_nn_Softmax.py b/tests/test_nn_Softmax.py index 646e817c5..edb125b6d 100644 --- a/tests/test_nn_Softmax.py +++ b/tests/test_nn_Softmax.py @@ -89,3 +89,18 @@ def test_case_4(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 10.0]]) + model = nn.Softmax() + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_Softmin.py b/tests/test_nn_Softmin.py index 8b5c43f0b..79da7f890 100644 --- a/tests/test_nn_Softmin.py +++ b/tests/test_nn_Softmin.py @@ -89,3 +89,18 @@ def test_case_4(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [7.0, 8.0, 8.0, 10.0]]) + model = nn.Softmin() + result = model(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_lp_pool1d.py b/tests/test_nn_functional_lp_pool1d.py index 6e8d7223d..681b691ba 100644 --- a/tests/test_nn_functional_lp_pool1d.py +++ b/tests/test_nn_functional_lp_pool1d.py @@ -115,3 +115,18 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[[-0.5743, 0.4889, -0.0878, 0.4210, -0.0844], + [0.3614, 0.8458, -0.6152, 0.6894, 0.2927], + [-0.0087, 0.1098, 0.1783, -0.6953, 0.5519], + [0.3789, -0.0560, -0.4090, -0.1070, -1.0139], + [0.9204, 1.0817, -2.6126, 0.4244, 0.3272]]]) + result = torch.nn.functional.lp_pool1d(input=input, kernel_size=3, stride=2, ceil_mode=True, norm_type=4) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_lp_pool2d.py b/tests/test_nn_functional_lp_pool2d.py index 50b04b22e..6bb6f7e11 100644 --- a/tests/test_nn_functional_lp_pool2d.py +++ b/tests/test_nn_functional_lp_pool2d.py @@ -141,3 +141,33 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn.functional as F + input = torch.tensor([[[[ 0.6430, 0.4511, -1.6757, 1.7116], + [-0.2288, -0.4111, -1.3602, 0.2685], + [ 0.2363, 1.9341, 0.8522, -0.1846], + [ 1.6496, -0.0675, -0.7208, -1.0018]], + + [[-0.3183, 0.8029, -0.4993, 1.0598], + [-0.4952, -0.9536, 0.1954, 0.0551], + [ 1.2257, 0.7517, 0.4063, -1.2151], + [-1.3562, 0.3547, 1.1147, 1.2898]], + + [[ 0.1205, -0.1889, 0.5086, -0.8080], + [ 0.3156, -0.8298, 2.0242, -0.9184], + [-0.4005, 1.3586, 0.6205, -0.7487], + [ 1.6239, 0.2900, 0.9671, 1.2961]], + + [[-1.1996, -0.2201, -0.9466, -0.7264], + [-0.0313, 0.8284, -0.3588, 1.3522], + [-0.0991, -0.5112, -0.1785, 2.0903], + [-1.3286, -0.9333, -0.1404, 1.2582]]]]) + result = F.lp_pool2d(input=input, kernel_size=2, stride=2, ceil_mode=True, norm_type=2) + """ + ) + obj.run(pytorch_code, ["result"], atol=1e-05, rtol=1e-06) diff --git a/tests/test_nn_functional_threshold_.py b/tests/test_nn_functional_threshold_.py index ff248f5d2..358e29bf8 100644 --- a/tests/test_nn_functional_threshold_.py +++ b/tests/test_nn_functional_threshold_.py @@ -82,3 +82,29 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(input=x, threshold=0.5, value=0.1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + result = nn.functional.threshold_(threshold=0.8, input=x, value=0.1) + """ + ) + obj.run(pytorch_code, ["result"]) From 7eb8406b36f95457ac8201a15affbe17479bb8af Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Tue, 26 Nov 2024 16:11:46 +0800 Subject: [PATCH 09/10] fix --- paconvert/api_matcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4f9617711..62bf8ae44 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4281,11 +4281,11 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) -class SoftminMatcher(SoftmaxMatcher): +class SoftminMatcher(BaseMatcher): def generate_code(self, kwargs): self.paddle_api = "paddle_aux.Softmin" self.write_aux_code() - return super().generate_code(kwargs) + return GenericMatcher.generate_code(self, kwargs) def generate_aux_code(self): CODE_TEMPLATE = textwrap.dedent( From 158f9240d1fde2a4297d247bb59007f7c88f90f6 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Tue, 26 Nov 2024 22:18:37 +0800 Subject: [PATCH 10/10] fix --- paconvert/api_matcher.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 62bf8ae44..31a892fa4 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4297,15 +4297,11 @@ def _get_softmax_dim(axis: int) -> int: ret = 1 return ret - def forward(self,x): - if self._axis is None: - return paddle.nn.functional.softmax(x, _get_softmax_dim(x.ndim)) - return paddle.nn.functional.softmax(x, self._axis) - setattr(paddle.nn.Softmax, 'forward', forward) - class Softmin(paddle.nn.Softmax): def forward(self, x): - return super().forward(-x) + if self._axis is None: + return paddle.nn.functional.softmax(-x, _get_softmax_dim(x.ndim)) + return paddle.nn.functional.softmax(-x, self._axis) """ ) return CODE_TEMPLATE