From 8863e7b8941a9d2aacefdaa72583dfdb7f63aa22 Mon Sep 17 00:00:00 2001
From: txyugood
Date: Mon, 26 Jun 2023 11:13:16 +0800
Subject: [PATCH] Conversion rules No.240-242. (#119)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add 240-242.

* refine matcher of torch.utils.data.random_split.

* refine matcher of torch.utils.data.random_split.

* refine matcher of torch.utils.data.random_split.
---
 paconvert/api_mapping.json             |  29 ++++
 paconvert/api_matcher.py               |  21 +++
 tests/test_utils_data_random_split.py  | 181 +++++++++++++++++++++++++
 tests/test_utils_dlpack_from_dlpack.py | 115 ++++++++++++++++
 tests/test_utils_dlpack_to_dlpack.py   | 115 ++++++++++++++++
 5 files changed, 461 insertions(+)
 create mode 100644 tests/test_utils_data_random_split.py
 create mode 100644 tests/test_utils_dlpack_from_dlpack.py
 create mode 100644 tests/test_utils_dlpack_to_dlpack.py

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 9d56fd866..06ead94a7 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -8610,6 +8610,35 @@
             "data_source"
         ]
     },
+    "torch.utils.data.random_split": {
+        "Matcher": "RandomSplitMatcher",
+        "paddle_api": "paddle.io.random_split",
+        "args_list": [
+            "dataset",
+            "lengths",
+            "generator"
+        ]
+    },
+    "torch.utils.dlpack.from_dlpack": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.utils.dlpack.from_dlpack",
+        "args_list": [
+            "ext_tensor"
+        ],
+        "kwargs_change": {
+            "ext_tensor": "dlpack"
+        }
+    },
+    "torch.utils.dlpack.to_dlpack": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.utils.dlpack.to_dlpack",
+        "args_list": [
+            "tensor"
+        ],
+        "kwargs_change": {
+            "tensor": "x"
+        }
+    },
     "torch.nn.functional.l1_loss": {
         "Matcher": "SizeAverageMatcher",
         "paddle_api": "paddle.nn.functional.l1_loss",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 0e5cc28f5..00deaebd2 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3645,3 +3645,24 @@ class SizeAverageMatcher(BaseMatcher):
     def generate_code(self, kwargs):
         process_reduce_and_size_average(kwargs)
         return GenericMatcher.generate_code(self, kwargs)
+
+
+class RandomSplitMatcher(BaseMatcher):
+    def generate_code(self, kwargs):
+        API_TEMPLATE = textwrap.dedent(
+            """
+            dataset_lengths = {}
+            if sum(dataset_lengths) <= 1:
+                dataset_lengths = [int(length * {}.__len__()) for length in dataset_lengths]
+            {}({})
+            """
+        )
+        lengths_v = kwargs["lengths"].strip("\n")
+        kwargs["lengths"] = "dataset_lengths"
+        code = API_TEMPLATE.format(
+            lengths_v,
+            kwargs["dataset"],
+            self.get_paddle_api(),
+            self.kwargs_to_str(kwargs),
+        )
+        return code.strip("\n")
diff --git a/tests/test_utils_data_random_split.py b/tests/test_utils_data_random_split.py
new file mode 100644
index 000000000..074fcc8fd
--- /dev/null
+++ b/tests/test_utils_data_random_split.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.utils.data.random_split")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset
+
+        class Data(Dataset):
+            def __init__(self):
+                self.x = [0,1,2,3,4,5,6,7,8,9]
+
+            def __getitem__(self, idx):
+                return self.x[idx]
+
+            def __len__(self):
+                return len(self.x)
+
+
+        data = Data()
+
+        datasets = torch.utils.data.random_split(data, [3, 7])
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset
+
+        class Data(Dataset):
+            def __init__(self):
+                self.x = [0,1,2,3,4,5,6,7,8,9]
+
+            def __getitem__(self, idx):
+                return self.x[idx]
+
+            def __len__(self):
+                return len(self.x)
+
+
+        data = Data()
+        datasets = torch.utils.data.random_split(data, [3, 3, 4])
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset
+
+        class Data(Dataset):
+            def __init__(self):
+                self.x = [0,1,2,3,4,5,6,7,8,9]
+
+            def __getitem__(self, idx):
+                return self.x[idx]
+
+            def __len__(self):
+                return len(self.x)
+
+
+        data = Data()
+        lengths = [3, 3, 4]
+        datasets = torch.utils.data.random_split(data, lengths)
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset
+
+        class Data(Dataset):
+            def __init__(self):
+                self.x = [0,1,2,3,4,5,6,7,8,9]
+
+            def __getitem__(self, idx):
+                return self.x[idx]
+
+            def __len__(self):
+                return len(self.x)
+
+
+        data = Data()
+        lengths = [0.4, 0.4, 0.2]
+        datasets = torch.utils.data.random_split(data, lengths)
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        datasets = torch.utils.data.random_split(range(30), [0.4, 0.4, 0.2])
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        lengths = [0.4, 0.4, 0.2]
+        data = range(30)
+        datasets = torch.utils.data.random_split(data, lengths)
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        lengths = [0.4, 0.4, 0.2]
+        data = range(30)
+        datasets = torch.utils.data.random_split(data, lengths, generator=torch.Generator().manual_seed(42))
+
+        results = []
+        for d in datasets:
+            results.append(d.__len__())
+        """
+    )
+    obj.run(pytorch_code, ["results"])
diff --git a/tests/test_utils_dlpack_from_dlpack.py b/tests/test_utils_dlpack_from_dlpack.py
new file mode 100644
index 000000000..9b7a12ee8
--- /dev/null
+++ b/tests/test_utils_dlpack_from_dlpack.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.utils.dlpack.from_dlpack")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).int()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).long()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).half()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).double()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).float()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).short()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).byte()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).char()
+        capsule = torch.utils.dlpack.to_dlpack(t)
+        result = torch.utils.dlpack.from_dlpack(capsule)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_utils_dlpack_to_dlpack.py b/tests/test_utils_dlpack_to_dlpack.py
new file mode 100644
index 000000000..66e237e28
--- /dev/null
+++ b/tests/test_utils_dlpack_to_dlpack.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+
+class DLPackAPIBase(APIBase):
+    def compare(self, name, pytorch_result, paddle_result, check_value=True):
+        if type(paddle_result).__name__ == "PyCapsule":
+            return True
+        return False
+
+
+obj = DLPackAPIBase("torch.utils.dlpack.to_dlpack")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).int()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(4).long()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(4, 2).half()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(4, 2).double()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(3, 3).float()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(3, 3).short()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(3, 3).byte()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.randn(3, 3).char()
+        result = torch.utils.dlpack.to_dlpack(t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
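
A minimal sketch of the rewrite the RandomSplitMatcher above is expected to produce for a call with fractional lengths. The sample dataset, the keyword form of the final call, and the printed split sizes are illustrative assumptions rather than captured converter output; the tests above do show that plain sequences such as range(30) are accepted by both sides of the mapping.

    import paddle

    # Hypothetical input on the PyTorch side:
    #   datasets = torch.utils.data.random_split(data, [0.4, 0.4, 0.2])
    data = list(range(30))  # any dataset-like object with __len__ and indexing

    # Roughly what the matcher's API_TEMPLATE expands to: fractional lengths
    # (summing to <= 1) are converted to absolute lengths before the call.
    dataset_lengths = [0.4, 0.4, 0.2]
    if sum(dataset_lengths) <= 1:
        dataset_lengths = [int(length * data.__len__()) for length in dataset_lengths]
    datasets = paddle.io.random_split(dataset=data, lengths=dataset_lengths)

    print([len(d) for d in datasets])  # expected: [12, 12, 6]

The sum(dataset_lengths) <= 1 guard is what lets a single mapping cover both absolute lengths (e.g. [3, 7]) and the fractional form that torch.utils.data.random_split also accepts, since paddle.io.random_split expects absolute lengths.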