diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index dd045688a..8780d6003 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -1989,7 +1989,16 @@ ] }, "torch.Tensor.resize_": {}, - "torch.Tensor.resize_as_": {}, + "torch.Tensor.resize_as_": { + "Matcher": "TensorResize_as_Matcher", + "args_list": [ + "the_template", + "memory_format" + ], + "kwargs_change": { + "memory_format": "" + } + }, "torch.Tensor.resolve_conj": {}, "torch.Tensor.resolve_neg": {}, "torch.Tensor.retain_grad": { @@ -2027,7 +2036,12 @@ "out" ] }, - "torch.Tensor.round_": {}, + "torch.Tensor.round_": { + "Matcher": "TensorRound_Matcher", + "args_list": [ + "decimals" + ] + }, "torch.Tensor.rsqrt": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.rsqrt" @@ -2066,7 +2080,23 @@ "src": "values" } }, - "torch.Tensor.scatter_add": {}, + "torch.Tensor.scatter_add": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.put_along_axis", + "args_list": [ + "dim", + "index", + "src" + ], + "kwargs_change": { + "dim": "axis", + "index": "indices", + "src": "values" + }, + "paddle_default_kwargs": { + "reduce": "'add'" + } + }, "torch.Tensor.scatter_add_": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.put_along_axis_", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4951d4000..f3bdcc154 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3081,6 +3081,18 @@ def generate_code(self, kwargs): return code +class TensorResize_as_Matcher(BaseMatcher): + def generate_code(self, kwargs): + + API_TEMPLATE = textwrap.dedent( + """ + {}.reshape_({}.shape) + """ + ) + code = API_TEMPLATE.format(self.paddleClass, kwargs["the_template"]) + return code + + class SelectMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: @@ -3688,6 +3700,23 @@ def get_paddle_class_nodes(self, func, args, kwargs): return "unchange" +class TensorRound_Matcher(BaseMatcher): + def generate_code(self, kwargs): + kwargs["input"] = self.paddleClass + + if "decimals" in kwargs: + API_TEMPLATE = textwrap.dedent( + """ + paddle.assign(({} * (10**{})).round_() / (10**{}), {}) + """ + ) + return API_TEMPLATE.format( + kwargs["input"], kwargs["decimals"], kwargs["decimals"], kwargs["input"] + ) + else: + return "{}.round_()".format(kwargs["input"]) + + class NonzeroMatcher(BaseMatcher): def generate_code(self, kwargs): if "as_tuple" in kwargs and kwargs["as_tuple"] != "(False)": diff --git a/tests/test_Tensor_resize_as_.py b/tests/test_Tensor_resize_as_.py new file mode 100644 index 000000000..b40d9c672 --- /dev/null +++ b/tests/test_Tensor_resize_as_.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.resize_as_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.ones([15]) + b = torch.zeros([3, 5]) + result = a.resize_as_(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.ones([15]) + b = torch.zeros([3, 5]) + result = a.resize_as_(the_template=b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.ones([15]) + b = torch.zeros([3, 5]) + result = a.resize_as_(memory_format=torch.contiguous_format, the_template=b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.ones([15]) + b = torch.zeros([3, 5]) + result = a.resize_as_(b+1, memory_format=torch.contiguous_format) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.ones([15]) + b = torch.zeros([3, 5]) + result = a.resize_as_(the_template=b+1, memory_format=torch.contiguous_format) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_round_.py b/tests/test_Tensor_round_.py new file mode 100644 index 000000000..4fc5ce50c --- /dev/null +++ b/tests/test_Tensor_round_.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.round_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 0.9254, -0.6213]]) + result = a.round_() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 102003.9254, -12021.6213]]) + result = a.round_(decimals=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 102003.9254, -12021.6213]]) + result = a.round_(decimals=-1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 102003.9254, -12021.6213]]) + result = a.round_(decimals=3) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[ 102003.9254, -12021.6213]]) + result = a.round_(decimals=-3) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_scatter_add.py b/tests/test_Tensor_scatter_add.py new file mode 100644 index 000000000..5ce76200c --- /dev/null +++ b/tests/test_Tensor_scatter_add.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.scatter_add") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.ones((1, 5)) + index = torch.tensor([[0, 1, 2, 0, 0]]) + result = torch.zeros(3, 5, dtype=src.dtype).scatter_add(0, index, src) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.ones((2, 5)) + index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + result = torch.zeros(3, 5, dtype=src.dtype).scatter_add(dim=0, index=index, src=src) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.ones((2, 5)) + index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + result = torch.zeros(3, 5, dtype=src.dtype).scatter_add(dim=0, src=src, index=index) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.ones((1, 5)) + index = torch.tensor([[0, 1, 2, 0, 0]]) + result = torch.zeros(3, 5, dtype=src.dtype).scatter_add(0, index, src) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + src = torch.ones((2, 5)) + index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + result = torch.zeros(2, 5, dtype=src.dtype).scatter_add(dim=1, index=index, src=src) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_Threshold.py b/tests/test_nn_Threshold.py new file mode 100644 index 000000000..9d0bb4099 --- /dev/null +++ b/tests/test_nn_Threshold.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.Threshold") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + model = nn.Threshold(0.5, 0.0) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support value != 0.0 and value is mandatory in torch", + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + model = nn.Threshold(threshold=0.5, value=0.0) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support value != 0.0 and value is mandatory in torch", + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + model = nn.Threshold(value=0.0, threshold=0.5) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support value != 0.0 and value is mandatory in torch", + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + model = nn.Threshold(0.5, 0.0, False) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support value != 0.0 and value is mandatory in torch", + ) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913], + [-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]]) + model = nn.Threshold(threshold=0.5, value=0.1, inplace=False) + result = model(x) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support value != 0.0 and value is mandatory in torch", + )