# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import os
import sys
import numpy as np
sys.path.append(os.path.dirname(__file__) + "/..")
from paconvert.converter import Converter
class APIBase(object):
    """Test harness for checking a single PyTorch->Paddle API conversion.

    Usage: construct with the PyTorch API name under test, then call
    ``run`` with a PyTorch snippet. The harness executes the snippet,
    converts it to Paddle code via paconvert's ``Converter``, executes the
    converted code, and compares the named result variables between the
    two runs.
    """

    def __init__(self, pytorch_api) -> None:
        """
        args:
            pytorch_api: The corresponding pytorch api
        """
        self.pytorch_api = pytorch_api

    def run(
        self,
        pytorch_code,
        compared_tensor_names=None,
        expect_paddle_code=None,
        check_value=True,
        check_dtype=True,
        check_stop_gradient=True,
        rtol=1.0e-6,
        atol=0.0,
        unsupport=False,
        reason=None,
    ) -> None:
        """
        Execute ``pytorch_code``, convert it to Paddle code, execute the
        converted code, and compare the results.

        args:
            pytorch_code: pytorch code to execute
            compared_tensor_names: the list of variant name to be compared
            expect_paddle_code: the string of expect paddle code
            check_value: If false, the value will not be checked
            check_dtype: If false, the dtype will not be checked
            check_stop_gradient: If false, the stop gradient will not be checked
            rtol: relative tolerance passed to np.allclose
            atol: absolute tolerance passed to np.allclose
            unsupport: If true, conversion is not supported
            reason: the reason why it is not supported
        """
        if unsupport:
            assert (
                reason is not None
            ), "Please explain the reason why it is not supported"
            paddle_code = self.convert(pytorch_code)
            # An unconvertible API leaves a ">>>>>>" marker in the converted
            # output, which is what we assert on here.
            assert ">>>>>>" in paddle_code
            return
        if compared_tensor_names:
            # NOTE(fix): run each snippet inside its own fresh namespace dict
            # instead of exec(code, locals()). The previous pattern captured
            # `loc = locals()` and relied on a later locals() call returning
            # the *same* dict object so that exec'd names showed up in `loc`.
            # That is a CPython implementation detail which no longer holds
            # under PEP 667 (Python 3.13+), and it also let snippet variables
            # collide with the harness's own locals.
            pytorch_scope = {}
            exec(pytorch_code, pytorch_scope)
            pytorch_result = [pytorch_scope[name] for name in compared_tensor_names]
            paddle_code = self.convert(pytorch_code)
            paddle_scope = {}
            exec(paddle_code, paddle_scope)
            paddle_result = [paddle_scope[name] for name in compared_tensor_names]
            for torch_value, paddle_value in zip(pytorch_result, paddle_result):
                self.compare(
                    self.pytorch_api,
                    torch_value,
                    paddle_value,
                    check_value,
                    check_dtype,
                    check_stop_gradient,
                    rtol,
                    atol,
                )
        if expect_paddle_code:
            converted_code = self.convert(pytorch_code).strip()
            expected_code = expect_paddle_code.strip()
            if converted_code != expected_code:
                # Build a unified diff so the assertion message pinpoints
                # exactly where conversion output diverged from expectation.
                diff = difflib.unified_diff(
                    expected_code.splitlines(),
                    converted_code.splitlines(),
                    fromfile="expected",
                    tofile="converted",
                    lineterm="",
                )
                diff_text = "\n".join(diff)
                error_msg = (
                    f"[{self.pytorch_api}] Code conversion result differs from expectation:\n"
                    f"{'-'*50}\n"
                    f"Diff comparison:\n"
                    f"{diff_text}\n"
                    f"{'-'*50}"
                )
                assert converted_code == expected_code, error_msg

    def compare(
        self,
        name,
        pytorch_result,
        paddle_result,
        check_value=True,
        check_dtype=True,
        check_stop_gradient=True,
        rtol=1.0e-6,
        atol=0.0,
    ):
        """
        compare tensors' data, shape, requires_grad, dtype

        Containers (dict/list/tuple) are compared recursively; plain scalars
        are compared by type and value; anything else is assumed to be a
        torch tensor (pytorch_result) / paddle tensor (paddle_result) pair.

        args:
            name: pytorch api name
            pytorch_result: pytorch Tensor
            paddle_result: paddle Tensor
            check_value: If false, the value will not be checked
            check_dtype: If false, the dtype will not be checked
            check_stop_gradient: If false, the stop gradient will not be checked
            rtol: relative tolerance passed to np.allclose
            atol: absolute tolerance passed to np.allclose
        """
        if isinstance(pytorch_result, dict):
            assert isinstance(paddle_result, dict), "paddle result should be dict"
            assert len(pytorch_result) == len(
                paddle_result
            ), "paddle result have different length with pytorch"
            # Compare keys and values as two parallel lists (recursively).
            self.compare(
                self.pytorch_api,
                list(pytorch_result.keys()),
                list(paddle_result.keys()),
                check_value,
                check_dtype,
                check_stop_gradient,
                rtol,
                atol,
            )
            self.compare(
                self.pytorch_api,
                list(pytorch_result.values()),
                list(paddle_result.values()),
                check_value,
                check_dtype,
                check_stop_gradient,
                rtol,
                atol,
            )
            return
        if isinstance(pytorch_result, (tuple, list)):
            assert isinstance(
                paddle_result, (tuple, list)
            ), "paddle result should be list/tuple"
            assert len(pytorch_result) == len(
                paddle_result
            ), "paddle result have different length with pytorch"
            for torch_item, paddle_item in zip(pytorch_result, paddle_result):
                self.compare(
                    self.pytorch_api,
                    torch_item,
                    paddle_item,
                    check_value,
                    check_dtype,
                    check_stop_gradient,
                    rtol,
                    atol,
                )
            return
        if isinstance(pytorch_result, (bool, np.number, int, float, str, type(None))):
            # Exact type match is intentional here (not isinstance): e.g. a
            # bool result must not silently compare equal to an int.
            assert type(paddle_result) == type(
                pytorch_result
            ), "paddle result's type [{}] should be the same with pytorch's type [{}]".format(
                type(paddle_result), type(pytorch_result)
            )
            if check_value:
                assert (
                    pytorch_result == paddle_result
                ), "API ({}): pytorch result is {}, but paddle result is {}".format(
                    name, pytorch_result, paddle_result
                )
            return
        # Tensor path: pull both sides out to numpy for comparison.
        if pytorch_result.requires_grad:
            # A tensor with requires_grad must be detached before .numpy().
            pytorch_numpy, paddle_numpy = (
                pytorch_result.detach().numpy(),
                paddle_result.numpy(False),
            )
        elif pytorch_result.is_conj():
            # Conjugate-view tensors must be materialized first.
            pytorch_numpy, paddle_numpy = (
                pytorch_result.resolve_conj().numpy(),
                paddle_result.numpy(False),
            )
        else:
            (
                pytorch_numpy,
                paddle_numpy,
            ) = pytorch_result.cpu().numpy(), paddle_result.numpy(False)
        if check_stop_gradient:
            # torch's requires_grad and paddle's stop_gradient have opposite
            # meanings, so matching semantics means the flags must differ.
            assert (
                pytorch_result.requires_grad != paddle_result.stop_gradient
            ), "API ({}): requires grad mismatch, torch tensor's requires_grad is {}, paddle tensor's stop_gradient is {}".format(
                name, pytorch_result.requires_grad, paddle_result.stop_gradient
            )
        assert (
            pytorch_numpy.shape == paddle_numpy.shape
        ), "API ({}): shape mismatch, torch shape is {}, paddle shape is {}".format(
            name, pytorch_numpy.shape, paddle_numpy.shape
        )
        if check_dtype:
            assert (
                pytorch_numpy.dtype == paddle_numpy.dtype
            ), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format(
                name, pytorch_numpy.dtype, paddle_numpy.dtype
            )
        if check_value:
            assert np.allclose(
                pytorch_numpy, paddle_numpy, rtol=rtol, atol=atol
            ), "API ({}): paddle result has diff with pytorch result".format(name)

    def convert(self, pytorch_code):
        """
        convert pytorch code to paddle code.

        Writes the snippet to a temp file under ./test_project, runs the
        paconvert Converter on it, and reads back the converted file.

        args:
            pytorch_code: pytorch code to be converted.
        return:
            paddle code.
        """
        # NOTE(fix): exist_ok=True removes the exists()/makedirs() race that
        # could crash when tests run in parallel.
        project_dir = os.path.join(os.getcwd(), "test_project")
        os.makedirs(project_dir, exist_ok=True)
        pytorch_code_path = os.path.join(project_dir, "pytorch_temp.py")
        paddle_code_path = os.path.join(project_dir, "paddle_temp.py")
        with open(pytorch_code_path, "w", encoding="UTF-8") as f:
            f.write(pytorch_code)
        converter = Converter(log_dir="disable")
        converter.run(pytorch_code_path, paddle_code_path)
        with open(paddle_code_path, "r", encoding="UTF-8") as f:
            code = f.read()
        return code