Skip to content

Commit

Permalink
转换规则 No.22 No.25 torch.linalg.matrix_rank torch.qr (#159)
Browse files Browse the repository at this point in the history
* add convert rule No.22 and No.25

* fix ut error

* fixed

* fixed
  • Loading branch information
GreatV authored Jul 13, 2023
1 parent ba8b1cf commit d577029
Show file tree
Hide file tree
Showing 23 changed files with 333 additions and 25 deletions.
31 changes: 31 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4784,6 +4784,25 @@
"input": "x"
}
},
"torch.linalg.matrix_rank": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.linalg.matrix_rank",
"args_list": [
"A",
"tol",
"atol",
"rtol",
"hermitian",
"out"
],
"kwargs_change": {
"A": "x"
},
"unsupport_args": [
"atol",
"rtol"
]
},
"torch.linalg.multi_dot": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.linalg.multi_dot",
Expand Down Expand Up @@ -8332,6 +8351,18 @@
"dim": "axis"
}
},
"torch.qr": {
"Matcher": "QrMatcher",
"paddle_api": "paddle.linalg.qr",
"args_list": [
"input",
"some",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.rad2deg": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.rad2deg",
Expand Down
25 changes: 25 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3648,6 +3648,31 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class QrMatcher(BaseMatcher):
def generate_code(self, kwargs):
some_v = kwargs.pop("some") if "some" in kwargs else None
out_v = kwargs.pop("out") if "out" in kwargs else None

if some_v:
kwargs["mode"] = "'complete'" if some_v != "(False)" else "'reduced'"

if out_v:
kwargs["x"] = kwargs.pop("input")
API_TEMPLATE = textwrap.dedent(
"""
tmp_q, tmp_r = {}({})
paddle.assign(tmp_q, {}[0]), paddle.assign(tmp_r, {}[1])
"""
)

code = API_TEMPLATE.format(
self.get_paddle_api(), self.kwargs_to_str(kwargs), out_v, out_v
)
return code

return GenericMatcher.generate_code(self, kwargs)


class RandomSplitMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
Expand Down
24 changes: 17 additions & 7 deletions tests/apibase.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run(
compared_tensor_names=None,
expect_paddle_code=None,
check_value=True,
check_dtype=True,
unsupport=False,
reason=None,
is_aux_api=False,
Expand All @@ -47,6 +48,7 @@ def run(
compared_tensor_names: the list of variant name to be compared
expect_paddle_code: the string of expect paddle code
check_value: If false, the value will not be checked
check_dtype: If false, the dtype will not be checked
unsupport: If true, conversion is not supported
reason: the reason why it is not supported
is_aux_api: the bool value for api that need Auxiliary code
Expand Down Expand Up @@ -81,7 +83,11 @@ def run(
paddle_result = [loc[name] for name in compared_tensor_names]
for i in range(len(compared_tensor_names)):
self.compare(
self.pytorch_api, pytorch_result[i], paddle_result[i], check_value
self.pytorch_api,
pytorch_result[i],
paddle_result[i],
check_value,
check_dtype,
)

if expect_paddle_code:
Expand All @@ -90,14 +96,17 @@ def run(
convert_paddle_code == expect_paddle_code
), "[{}]: get unexpected code".format(self.pytorch_api)

def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
"""
compare tensors' data, shape, requires_grad, dtype
args:
name: pytorch api name
pytorch_result: pytorch Tensor
paddle_result: paddle Tensor
check_value: If false, the value will not be checked
check_dtype: If false, the dtype will not be checked
"""
if isinstance(pytorch_result, (tuple, list)):
assert isinstance(
Expand Down Expand Up @@ -147,11 +156,12 @@ def compare(self, name, pytorch_result, paddle_result, check_value=True):
), "API ({}): shape mismatch, torch shape is {}, paddle shape is {}".format(
name, pytorch_numpy.shape, paddle_numpy.shape
)
assert (
pytorch_numpy.dtype == paddle_numpy.dtype
), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format(
name, pytorch_numpy.dtype, paddle_numpy.dtype
)
if check_dtype:
assert (
pytorch_numpy.dtype == paddle_numpy.dtype
), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format(
name, pytorch_numpy.dtype, paddle_numpy.dtype
)
if check_value:
assert np.allclose(
pytorch_numpy, paddle_numpy
Expand Down
4 changes: 3 additions & 1 deletion tests/test_Generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class GeneratorAPIBase(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
if isinstance(paddle_result, paddle.fluid.libpaddle.Generator):
return True
return False
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cuda_current_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class cudaCurrentStreamAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
return pytorch_result == paddle_result or isinstance(
paddle_result, paddle.fluid.libpaddle.CUDAStream
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cuda_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class cudaEventAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
return pytorch_result == paddle_result or isinstance(
paddle_result, paddle.fluid.libpaddle.CUDAEvent
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cuda_get_device_properites.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class cudaGetDeviceProperitesAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
return pytorch_result == paddle_result or isinstance(
paddle_result, paddle.fluid.libpaddle._gpuDeviceProperties
)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_cuda_max_memory_allocated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@


class cudaMaxMemoryAllocatedAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
return pytorch_result == paddle_result


Expand Down
2 changes: 1 addition & 1 deletion tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class DeviceAPIBase(APIBase):
def compare(self, name, pytorch_result, paddle_result, value):
def compare(self, name, pytorch_result, paddle_result, value, dtype):
return str(pytorch_result) == str(paddle_result)


Expand Down
4 changes: 3 additions & 1 deletion tests/test_linalg_lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class LstsqAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
if isinstance(pytorch_result, (tuple, list)):
for i in range(len(pytorch_result)):
self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i])
Expand Down
141 changes: 141 additions & 0 deletions tests/test_linalg_matrix_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.linalg.matrix_rank")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
A = torch.eye(10)
A[0, 0] = 0
result = torch.linalg.matrix_rank(A)
"""
)
# NOTE: torch dtype is int64, paddle dtype is int32
obj.run(pytorch_code, ["result"], check_dtype=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
[ 0.3198, -0.2976, -0.0585],
[-1.6931, -0.3353, -0.2893]],
[[ 1.9757, -0.5959, -1.2041],
[ 0.8443, -0.4916, -1.6574],
[-0.2654, -1.0447, -0.8138]],
[[-0.4111, 1.0973, 0.2275],
[ 1.1851, 1.8233, 0.8187],
[-1.4107, -0.5473, 1.1431]],
[[ 0.0327, -0.8295, 0.0457],
[-0.6286, -0.2507, 0.7292],
[ 0.4075, -1.3918, -0.5015]]],
[[[-2.1256, 0.9310, 1.0743],
[ 1.9577, -0.1513, 0.1668],
[-0.1404, 1.6647, 0.7108]],
[[ 0.9001, 1.6930, -0.4966],
[-1.0432, -1.0742, 1.2273],
[-0.2711, -0.4740, -0.6381]],
[[-1.3099, -1.7540, 0.5443],
[ 0.3565, -2.3821, 0.8638],
[-1.3840, 0.8216, 0.2761]],
[[-0.5989, -0.4732, 1.3252],
[-0.7614, 1.0493, 0.8488],
[-0.1300, 0.1287, 0.6234]]]])
result = torch.linalg.matrix_rank(A, hermitian=True)
"""
)
obj.run(pytorch_code, ["result"], check_dtype=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
[ 0.3198, -0.2976, -0.0585],
[-1.6931, -0.3353, -0.2893]],
[[ 1.9757, -0.5959, -1.2041],
[ 0.8443, -0.4916, -1.6574],
[-0.2654, -1.0447, -0.8138]],
[[-0.4111, 1.0973, 0.2275],
[ 1.1851, 1.8233, 0.8187],
[-1.4107, -0.5473, 1.1431]],
[[ 0.0327, -0.8295, 0.0457],
[-0.6286, -0.2507, 0.7292],
[ 0.4075, -1.3918, -0.5015]]],
[[[-2.1256, 0.9310, 1.0743],
[ 1.9577, -0.1513, 0.1668],
[-0.1404, 1.6647, 0.7108]],
[[ 0.9001, 1.6930, -0.4966],
[-1.0432, -1.0742, 1.2273],
[-0.2711, -0.4740, -0.6381]],
[[-1.3099, -1.7540, 0.5443],
[ 0.3565, -2.3821, 0.8638],
[-1.3840, 0.8216, 0.2761]],
[[-0.5989, -0.4732, 1.3252],
[-0.7614, 1.0493, 0.8488],
[-0.1300, 0.1287, 0.6234]]]])
result = torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0)
"""
)
obj.run(
pytorch_code,
["result"],
check_dtype=False,
unsupport=True,
reason="paddle does not support `atol` and `rtol`",
)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
[ 0.3198, -0.2976, -0.0585],
[-1.6931, -0.3353, -0.2893]],
[[ 1.9757, -0.5959, -1.2041],
[ 0.8443, -0.4916, -1.6574],
[-0.2654, -1.0447, -0.8138]],
[[-0.4111, 1.0973, 0.2275],
[ 1.1851, 1.8233, 0.8187],
[-1.4107, -0.5473, 1.1431]],
[[ 0.0327, -0.8295, 0.0457],
[-0.6286, -0.2507, 0.7292],
[ 0.4075, -1.3918, -0.5015]]],
[[[-2.1256, 0.9310, 1.0743],
[ 1.9577, -0.1513, 0.1668],
[-0.1404, 1.6647, 0.7108]],
[[ 0.9001, 1.6930, -0.4966],
[-1.0432, -1.0742, 1.2273],
[-0.2711, -0.4740, -0.6381]],
[[-1.3099, -1.7540, 0.5443],
[ 0.3565, -2.3821, 0.8638],
[-1.3840, 0.8216, 0.2761]],
[[-0.5989, -0.4732, 1.3252],
[-0.7614, 1.0493, 0.8488],
[-0.1300, 0.1287, 0.6234]]]])
result = torch.empty((2, 4), dtype=torch.int64)
torch.linalg.matrix_rank(A, hermitian=True, out=result)
"""
)
obj.run(pytorch_code, ["result"], check_dtype=False)
4 changes: 3 additions & 1 deletion tests/test_nn_AdaptiveMaxPool1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class MaxPoolAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
if isinstance(pytorch_result, (tuple, list)):
for i in range(len(pytorch_result)):
self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i])
Expand Down
4 changes: 3 additions & 1 deletion tests/test_nn_AdaptiveMaxPool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class MaxPoolAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
if isinstance(pytorch_result, (tuple, list)):
for i in range(len(pytorch_result)):
self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i])
Expand Down
4 changes: 3 additions & 1 deletion tests/test_nn_AdaptiveMaxPool3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class MaxPoolAPI(APIBase):
def compare(self, name, pytorch_result, paddle_result, check_value=True):
def compare(
self, name, pytorch_result, paddle_result, check_value=True, check_dtype=True
):
if isinstance(pytorch_result, (tuple, list)):
for i in range(len(pytorch_result)):
self.compare(self.pytorch_api, pytorch_result[i], paddle_result[i])
Expand Down
Loading

0 comments on commit d577029

Please sign in to comment.