From d577029cbe33cbb11f828a6e9dc3e6fc84611848 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Thu, 13 Jul 2023 12:54:41 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E8=A7=84=E5=88=99=20No.22=20?= =?UTF-8?q?No.25=20`torch.linalg.matrix=5Frank`=20`torch.qr`=20(#159)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add convert rule No.22 and No.25 * fix ut error * fixed * fixed --- paconvert/api_mapping.json | 31 ++++ paconvert/api_matcher.py | 25 ++++ tests/apibase.py | 24 ++- tests/test_Generator.py | 4 +- tests/test_cuda_current_stream.py | 4 +- tests/test_cuda_event.py | 4 +- tests/test_cuda_get_device_properites.py | 4 +- tests/test_cuda_max_memory_allocated.py | 4 +- tests/test_device.py | 2 +- tests/test_linalg_lstsq.py | 4 +- tests/test_linalg_matrix_rank.py | 141 ++++++++++++++++++ tests/test_nn_AdaptiveMaxPool1d.py | 4 +- tests/test_nn_AdaptiveMaxPool2d.py | 4 +- tests/test_nn_AdaptiveMaxPool3d.py | 4 +- .../test_nn_functional_adaptive_max_pool1d.py | 4 +- .../test_nn_functional_adaptive_max_pool2d.py | 4 +- .../test_nn_functional_adaptive_max_pool3d.py | 4 +- tests/test_optim_Optimizer.py | 4 +- tests/test_optim_Optimizer_load_state_dict.py | 4 +- tests/test_optim_Optimizer_state_dict.py | 4 +- tests/test_optim_Optimizer_step.py | 4 +- tests/test_qr.py | 67 +++++++++ tests/test_utils_dlpack_to_dlpack.py | 4 +- 23 files changed, 333 insertions(+), 25 deletions(-) create mode 100644 tests/test_linalg_matrix_rank.py create mode 100644 tests/test_qr.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 3179aca90..87e02d749 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -4784,6 +4784,25 @@ "input": "x" } }, + "torch.linalg.matrix_rank": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.matrix_rank", + "args_list": [ + "A", + "tol", + "atol", + "rtol", + "hermitian", + "out" + ], + "kwargs_change": { + "A": "x" + }, + "unsupport_args": [ + "atol", + "rtol" + ] + }, "torch.linalg.multi_dot": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.multi_dot", @@ -8332,6 +8351,18 @@ "dim": "axis" } }, + "torch.qr": { + "Matcher": "QrMatcher", + "paddle_api": "paddle.linalg.qr", + "args_list": [ + "input", + "some", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.rad2deg": { "Matcher": "GenericMatcher", "paddle_api": "paddle.rad2deg", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 3c9c5e67c..86aba13a1 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3648,6 +3648,31 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class QrMatcher(BaseMatcher): + def generate_code(self, kwargs): + some_v = kwargs.pop("some") if "some" in kwargs else None + out_v = kwargs.pop("out") if "out" in kwargs else None + + if some_v: + kwargs["mode"] = "'complete'" if some_v != "(False)" else "'reduced'" + + if out_v: + kwargs["x"] = kwargs.pop("input") + API_TEMPLATE = textwrap.dedent( + """ + tmp_q, tmp_r = {}({}) + paddle.assign(tmp_q, {}[0]), paddle.assign(tmp_r, {}[1]) + """ + ) + + code = API_TEMPLATE.format( + self.get_paddle_api(), self.kwargs_to_str(kwargs), out_v, out_v + ) + return code + + return GenericMatcher.generate_code(self, kwargs) + + class RandomSplitMatcher(BaseMatcher): def generate_code(self, kwargs): API_TEMPLATE = textwrap.dedent( diff --git a/tests/apibase.py b/tests/apibase.py index 7d20ac0db..c37d194de 100644 --- a/tests/apibase.py +++ b/tests/apibase.py @@ -37,6 +37,7 @@ def run( compared_tensor_names=None, expect_paddle_code=None, check_value=True, + check_dtype=True, unsupport=False, reason=None, is_aux_api=False, @@ -47,6 +48,7 @@ def run( compared_tensor_names: the list of variant name to be compared expect_paddle_code: the string of expect paddle code check_value: If false, the value will not be checked + check_dtype: If false, the dtype will not be checked unsupport: If true, conversion is not supported reason: the reason why it is not supported is_aux_api: the bool value for api that need Auxiliary code @@ -81,7 +83,11 @@ def run( paddle_result = [loc[name] for name in compared_tensor_names] for i in range(len(compared_tensor_names)): self.compare( - self.pytorch_api, pytorch_result[i], paddle_result[i], check_value + self.pytorch_api, + pytorch_result[i], + paddle_result[i], + check_value, + check_dtype, ) if expect_paddle_code: @@ -90,7 +96,9 @@ def run( convert_paddle_code == expect_paddle_code ), "[{}]: get unexpected code".format(self.pytorch_api) - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): """ compare tensors' data, shape, requires_grad, dtype args: @@ -98,6 +106,7 @@ def compare(self, name, pytorch_result, paddle_result, check_value=True): pytorch_result: pytorch Tensor paddle_result: paddle Tensor check_value: If false, the value will not be checked + check_dtype: If false, the dtype will not be checked """ if isinstance(pytorch_result, (tuple, list)): assert isinstance( @@ -147,11 +156,12 @@ def compare(self, name, pytorch_result, paddle_result, check_value=True): ), "API ({}): shape mismatch, torch shape is {}, paddle shape is {}".format( name, pytorch_numpy.shape, paddle_numpy.shape ) - assert ( - pytorch_numpy.dtype == paddle_numpy.dtype - ), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format( - name, pytorch_numpy.dtype, paddle_numpy.dtype - ) + if check_dtype: + assert ( + pytorch_numpy.dtype == paddle_numpy.dtype + ), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format( + name, pytorch_numpy.dtype, paddle_numpy.dtype + ) if check_value: assert np.allclose( pytorch_numpy, paddle_numpy diff --git a/tests/test_Generator.py b/tests/test_Generator.py index aef85506d..35a31c434 100644 --- a/tests/test_Generator.py +++ b/tests/test_Generator.py @@ -19,7 +19,9 @@ class GeneratorAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(paddle_result, paddle.fluid.libpaddle.Generator): return True return False diff --git a/tests/test_cuda_current_stream.py b/tests/test_cuda_current_stream.py index 39d1af6c4..d448fcfe7 100644 --- a/tests/test_cuda_current_stream.py +++ b/tests/test_cuda_current_stream.py @@ -19,7 +19,9 @@ class cudaCurrentStreamAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result == paddle_result or isinstance( paddle_result, paddle.fluid.libpaddle.CUDAStream ) diff --git a/tests/test_cuda_event.py b/tests/test_cuda_event.py index 7f27fcf38..35c70104f 100644 --- a/tests/test_cuda_event.py +++ b/tests/test_cuda_event.py @@ -19,7 +19,9 @@ class cudaEventAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result == paddle_result or isinstance( paddle_result, paddle.fluid.libpaddle.CUDAEvent ) diff --git a/tests/test_cuda_get_device_properites.py b/tests/test_cuda_get_device_properites.py index 224be3d4b..13a763dc6 100644 --- a/tests/test_cuda_get_device_properites.py +++ b/tests/test_cuda_get_device_properites.py @@ -19,7 +19,9 @@ class cudaGetDeviceProperitesAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result == paddle_result or isinstance( paddle_result, paddle.fluid.libpaddle._gpuDeviceProperties ) diff --git a/tests/test_cuda_max_memory_allocated.py b/tests/test_cuda_max_memory_allocated.py index 50ac9f5fb..b0cc5a2b8 100644 --- a/tests/test_cuda_max_memory_allocated.py +++ b/tests/test_cuda_max_memory_allocated.py @@ -18,7 +18,9 @@ class cudaMaxMemoryAllocatedAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result == paddle_result diff --git a/tests/test_device.py b/tests/test_device.py index 7a04673a4..a0c159917 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -18,7 +18,7 @@ class DeviceAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, value): + def compare(self, name, pytorch_result, paddle_result, value, dtype): return str(pytorch_result) == str(paddle_result) diff --git a/tests/test_linalg_lstsq.py b/tests/test_linalg_lstsq.py index 72732b57c..9e130d6b2 100644 --- a/tests/test_linalg_lstsq.py +++ b/tests/test_linalg_lstsq.py @@ -19,7 +19,9 @@ class LstsqAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_linalg_matrix_rank.py b/tests/test_linalg_matrix_rank.py new file mode 100644 index 000000000..183f93885 --- /dev/null +++ b/tests/test_linalg_matrix_rank.py @@ -0,0 +1,141 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.matrix_rank") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.eye(10) + A[0, 0] = 0 + result = torch.linalg.matrix_rank(A) + """ + ) + # NOTE: torch dtype is int64, paddle dtype is int32 + obj.run(pytorch_code, ["result"], check_dtype=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[[[-1.1079, -0.4803, 0.2296], + [ 0.3198, -0.2976, -0.0585], + [-1.6931, -0.3353, -0.2893]], + [[ 1.9757, -0.5959, -1.2041], + [ 0.8443, -0.4916, -1.6574], + [-0.2654, -1.0447, -0.8138]], + [[-0.4111, 1.0973, 0.2275], + [ 1.1851, 1.8233, 0.8187], + [-1.4107, -0.5473, 1.1431]], + [[ 0.0327, -0.8295, 0.0457], + [-0.6286, -0.2507, 0.7292], + [ 0.4075, -1.3918, -0.5015]]], + [[[-2.1256, 0.9310, 1.0743], + [ 1.9577, -0.1513, 0.1668], + [-0.1404, 1.6647, 0.7108]], + [[ 0.9001, 1.6930, -0.4966], + [-1.0432, -1.0742, 1.2273], + [-0.2711, -0.4740, -0.6381]], + [[-1.3099, -1.7540, 0.5443], + [ 0.3565, -2.3821, 0.8638], + [-1.3840, 0.8216, 0.2761]], + [[-0.5989, -0.4732, 1.3252], + [-0.7614, 1.0493, 0.8488], + [-0.1300, 0.1287, 0.6234]]]]) + result = torch.linalg.matrix_rank(A, hermitian=True) + """ + ) + obj.run(pytorch_code, ["result"], check_dtype=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[[[-1.1079, -0.4803, 0.2296], + [ 0.3198, -0.2976, -0.0585], + [-1.6931, -0.3353, -0.2893]], + [[ 1.9757, -0.5959, -1.2041], + [ 0.8443, -0.4916, -1.6574], + [-0.2654, -1.0447, -0.8138]], + [[-0.4111, 1.0973, 0.2275], + [ 1.1851, 1.8233, 0.8187], + [-1.4107, -0.5473, 1.1431]], + [[ 0.0327, -0.8295, 0.0457], + [-0.6286, -0.2507, 0.7292], + [ 0.4075, -1.3918, -0.5015]]], + [[[-2.1256, 0.9310, 1.0743], + [ 1.9577, -0.1513, 0.1668], + [-0.1404, 1.6647, 0.7108]], + [[ 0.9001, 1.6930, -0.4966], + [-1.0432, -1.0742, 1.2273], + [-0.2711, -0.4740, -0.6381]], + [[-1.3099, -1.7540, 0.5443], + [ 0.3565, -2.3821, 0.8638], + [-1.3840, 0.8216, 0.2761]], + [[-0.5989, -0.4732, 1.3252], + [-0.7614, 1.0493, 0.8488], + [-0.1300, 0.1287, 0.6234]]]]) + result = torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_dtype=False, + unsupport=True, + reason="paddle does not support `atol` and `rtol`", + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[[[-1.1079, -0.4803, 0.2296], + [ 0.3198, -0.2976, -0.0585], + [-1.6931, -0.3353, -0.2893]], + [[ 1.9757, -0.5959, -1.2041], + [ 0.8443, -0.4916, -1.6574], + [-0.2654, -1.0447, -0.8138]], + [[-0.4111, 1.0973, 0.2275], + [ 1.1851, 1.8233, 0.8187], + [-1.4107, -0.5473, 1.1431]], + [[ 0.0327, -0.8295, 0.0457], + [-0.6286, -0.2507, 0.7292], + [ 0.4075, -1.3918, -0.5015]]], + [[[-2.1256, 0.9310, 1.0743], + [ 1.9577, -0.1513, 0.1668], + [-0.1404, 1.6647, 0.7108]], + [[ 0.9001, 1.6930, -0.4966], + [-1.0432, -1.0742, 1.2273], + [-0.2711, -0.4740, -0.6381]], + [[-1.3099, -1.7540, 0.5443], + [ 0.3565, -2.3821, 0.8638], + [-1.3840, 0.8216, 0.2761]], + [[-0.5989, -0.4732, 1.3252], + [-0.7614, 1.0493, 0.8488], + [-0.1300, 0.1287, 0.6234]]]]) + result = torch.empty((2, 4), dtype=torch.int64) + torch.linalg.matrix_rank(A, hermitian=True, out=result) + """ + ) + obj.run(pytorch_code, ["result"], check_dtype=False) diff --git a/tests/test_nn_AdaptiveMaxPool1d.py b/tests/test_nn_AdaptiveMaxPool1d.py index 5feafccc0..3d782a528 100644 --- a/tests/test_nn_AdaptiveMaxPool1d.py +++ b/tests/test_nn_AdaptiveMaxPool1d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_nn_AdaptiveMaxPool2d.py b/tests/test_nn_AdaptiveMaxPool2d.py index b2ed9f2ee..a73c6f233 100644 --- a/tests/test_nn_AdaptiveMaxPool2d.py +++ b/tests/test_nn_AdaptiveMaxPool2d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_nn_AdaptiveMaxPool3d.py b/tests/test_nn_AdaptiveMaxPool3d.py index 655feb59a..834852c87 100644 --- a/tests/test_nn_AdaptiveMaxPool3d.py +++ b/tests/test_nn_AdaptiveMaxPool3d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_nn_functional_adaptive_max_pool1d.py b/tests/test_nn_functional_adaptive_max_pool1d.py index e1bf329ea..5cf52b282 100644 --- a/tests/test_nn_functional_adaptive_max_pool1d.py +++ b/tests/test_nn_functional_adaptive_max_pool1d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_nn_functional_adaptive_max_pool2d.py b/tests/test_nn_functional_adaptive_max_pool2d.py index 6501aa46d..7b11b2192 100644 --- a/tests/test_nn_functional_adaptive_max_pool2d.py +++ b/tests/test_nn_functional_adaptive_max_pool2d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_nn_functional_adaptive_max_pool3d.py b/tests/test_nn_functional_adaptive_max_pool3d.py index b2ee25481..4f16771dc 100644 --- a/tests/test_nn_functional_adaptive_max_pool3d.py +++ b/tests/test_nn_functional_adaptive_max_pool3d.py @@ -19,7 +19,9 @@ class MaxPoolAPI(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(pytorch_result, (tuple, list)): for i in range(len(pytorch_result)): self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i]) diff --git a/tests/test_optim_Optimizer.py b/tests/test_optim_Optimizer.py index eec582625..ad4fab883 100644 --- a/tests/test_optim_Optimizer.py +++ b/tests/test_optim_Optimizer.py @@ -19,7 +19,9 @@ class optimOptimizerAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if isinstance(paddle_result, paddle.optimizer.optimizer.Optimizer): return True return False diff --git a/tests/test_optim_Optimizer_load_state_dict.py b/tests/test_optim_Optimizer_load_state_dict.py index 522f7c2b2..29a25423e 100644 --- a/tests/test_optim_Optimizer_load_state_dict.py +++ b/tests/test_optim_Optimizer_load_state_dict.py @@ -18,7 +18,9 @@ class optimOptimizerLoadStateDictAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result["state"] == paddle_result diff --git a/tests/test_optim_Optimizer_state_dict.py b/tests/test_optim_Optimizer_state_dict.py index a0bed38d7..275380f00 100644 --- a/tests/test_optim_Optimizer_state_dict.py +++ b/tests/test_optim_Optimizer_state_dict.py @@ -18,7 +18,9 @@ class optimOptimizerStateDictAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): return pytorch_result["state"] == paddle_result diff --git a/tests/test_optim_Optimizer_step.py b/tests/test_optim_Optimizer_step.py index 724d0095d..f07de13af 100644 --- a/tests/test_optim_Optimizer_step.py +++ b/tests/test_optim_Optimizer_step.py @@ -18,7 +18,9 @@ class optimOptimizerAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if paddle_result == pytorch_result: return True return False diff --git a/tests/test_qr.py b/tests/test_qr.py new file mode 100644 index 000000000..d1969d136 --- /dev/null +++ b/tests/test_qr.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.qr") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + q, r = torch.qr(a) + """ + ) + obj.run(pytorch_code, ["q", "r"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + q, r = torch.qr(a, some=False) + """ + ) + obj.run(pytorch_code, ["q", "r"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + q = torch.empty((3, 3), dtype=torch.float32) + r = torch.empty((3, 3), dtype=torch.float32) + torch.qr(a, some=False, out=(q, r)) + """ + ) + obj.run(pytorch_code, ["q", "r"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + q = torch.empty((3, 3), dtype=torch.float32) + r = torch.empty((3, 3), dtype=torch.float32) + result = torch.qr(a, some=False, out=(q, r)) + """ + ) + obj.run(pytorch_code, ["result", "q", "r"]) diff --git a/tests/test_utils_dlpack_to_dlpack.py b/tests/test_utils_dlpack_to_dlpack.py index 66e237e28..34216626f 100644 --- a/tests/test_utils_dlpack_to_dlpack.py +++ b/tests/test_utils_dlpack_to_dlpack.py @@ -18,7 +18,9 @@ class DLPackAPIBase(APIBase): - def compare(self, name, pytorch_result, paddle_result, check_value=True): + def compare( + self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True + ): if type(paddle_result).__name__ == "PyCapsule": return True return False