Skip to content

Commit 8a17069

Browse files
committed
[API-Compat] Updated forbid-keyword decorator
1 parent a637f58 commit 8a17069

File tree

2 files changed

+129
-33
lines changed

2 files changed

+129
-33
lines changed

python/paddle/tensor/compat.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616

1717
from typing import TYPE_CHECKING, NamedTuple
1818

19+
import paddle
1920
from paddle import _C_ops
2021

22+
from ..framework import (
23+
in_dynamic_mode,
24+
)
25+
2126
if TYPE_CHECKING:
2227

2328
from paddle import Tensor
@@ -32,6 +37,33 @@ class SortRetType(NamedTuple):
3237
indices: Tensor
3338

3439

40+
def _check_out_status(
41+
out: Tensor | tuple[Tensor, Tensor] | list[Tensor],
42+
expect_multiple: bool = False,
43+
):
44+
if out is None:
45+
return
46+
if not in_dynamic_mode():
47+
raise RuntimeError(
48+
"Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n"
49+
)
50+
if expect_multiple:
51+
if not isinstance(out, (tuple, list)) or len(out) != 2:
52+
raise TypeError(
53+
f"Expected a list or tuple of two tensors, got {type(out)} instead."
54+
)
55+
if not (
56+
isinstance(out[0], paddle.Tensor)
57+
and isinstance(out[1], paddle.Tensor)
58+
):
59+
raise TypeError(
60+
f"Expected Tensor type in the tuple/list, got ({type(out[0])}, {type(out[1])}) instead."
61+
)
62+
else:
63+
if not isinstance(out, paddle.Tensor):
64+
raise TypeError(f"Expected a Tensor, got {type(out)} instead.")
65+
66+
3567
@ForbidKeywordsDecorator(
3668
illegal_keys={'x', 'axis'},
3769
func_name="paddle.compat.sort",
@@ -42,6 +74,7 @@ def sort(
4274
dim: int = -1,
4375
descending: bool = False,
4476
stable: bool = False,
77+
out=None,
4578
) -> SortRetType:
4679
"""
4780
@@ -59,6 +92,8 @@ def sort(
5992
stable (bool, optional): Whether to use stable sorting algorithm or not.
6093
When using stable sorting algorithm, the order of equivalent elements
6194
will be preserved. Default is False.
95+
out (tuple, optional) : the output tuple/list of (Tensor, Tensor) that
96+
can be optionally given to be used as output buffers
6297
6398
Returns:
6499
SortRetType, a named tuple which contains `values` and `indices`, can be accessed through either indexing
@@ -112,5 +147,9 @@ def sort(
112147
[1, 2, 0, 2],
113148
[2, 0, 2, 0]]]))
114149
"""
150+
_check_out_status(out, expect_multiple=True)
115151
outputs, indices = _C_ops.argsort(input, dim, descending, stable)
116-
return SortRetType(values=outputs, indices=indices)
152+
if out is None:
153+
return SortRetType(values=outputs, indices=indices)
154+
paddle.assign(outputs, out[0])
155+
paddle.assign(indices, out[1])

test/legacy_test/test_compat_sort.py

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,22 @@
1717
import numpy as np
1818

1919
import paddle
20-
from paddle.compat import sort
20+
from paddle.compat import sort as compat_sort
2121

2222

2323
class TestCompatSort(unittest.TestCase):
2424

2525
def _compare_with_origin(
26-
self, input_tensor, dtype, dim, descending, stable
26+
self, input_tensor, dtype, dim, descending, stable, use_out=False
2727
):
28-
sort_res = sort(
29-
input_tensor, dim=dim, descending=descending, stable=stable
30-
)
28+
"""DO NOT set use_out to be True in static graph mode."""
29+
if use_out:
30+
sort_res = (paddle.to_tensor(0), paddle.to_tensor(0))
31+
compat_sort(input_tensor, dim, descending, stable, out=sort_res)
32+
else:
33+
sort_res = compat_sort(
34+
input_tensor, dim=dim, descending=descending, stable=stable
35+
)
3136

3237
origin_vals = paddle.sort(
3338
input_tensor, axis=dim, descending=descending, stable=stable
@@ -37,15 +42,11 @@ def _compare_with_origin(
3742
)
3843
if dtype.find("int"):
3944
np.testing.assert_array_equal(
40-
sort_res.values.numpy(), origin_vals.numpy()
45+
sort_res[0].numpy(), origin_vals.numpy()
4146
)
4247
else:
43-
np.testing.assert_allclose(
44-
sort_res.values.numpy(), origin_vals.numpy()
45-
)
46-
np.testing.assert_array_equal(
47-
sort_res.indices.numpy(), origin_inds.numpy()
48-
)
48+
np.testing.assert_allclose(sort_res[0].numpy(), origin_vals.numpy())
49+
np.testing.assert_array_equal(sort_res[1].numpy(), origin_inds.numpy())
4950

5051
def test_with_origin_static(self):
5152
dtypes = [
@@ -75,7 +76,7 @@ def static_graph_tester(descending, stable):
7576
input_data = paddle.static.data(
7677
name='x', shape=shape, dtype=dtype
7778
)
78-
sort_res = sort(
79+
sort_res = compat_sort(
7980
input_data,
8081
dim=dim,
8182
descending=descending,
@@ -149,19 +150,40 @@ def test_with_origin_dynamic(self, use_static=False):
149150
input_tensor = paddle.randint(0, 255, shape).to(dtype)
150151
else:
151152
input_tensor = paddle.randn(shape, dtype=dtype)
152-
for dim in range(len(shape)):
153-
self._compare_with_origin(
154-
input_tensor, dtype, dim, False, False
155-
)
156-
self._compare_with_origin(
157-
input_tensor, dtype, dim - len(shape), False, True
158-
)
159-
self._compare_with_origin(
160-
input_tensor, dtype, dim, True, False
161-
)
162-
self._compare_with_origin(
163-
input_tensor, dtype, dim - len(shape), True, True
164-
)
153+
for use_out in [False, True]:
154+
for dim in range(len(shape)):
155+
self._compare_with_origin(
156+
input_tensor,
157+
dtype,
158+
dim,
159+
False,
160+
False,
161+
use_out=use_out,
162+
)
163+
self._compare_with_origin(
164+
input_tensor,
165+
dtype,
166+
dim - len(shape),
167+
False,
168+
True,
169+
use_out=use_out,
170+
)
171+
self._compare_with_origin(
172+
input_tensor,
173+
dtype,
174+
dim,
175+
True,
176+
False,
177+
use_out=use_out,
178+
)
179+
self._compare_with_origin(
180+
input_tensor,
181+
dtype,
182+
dim - len(shape),
183+
True,
184+
True,
185+
use_out=use_out,
186+
)
165187

166188
def test_sort_backward(self):
167189
"""test the backward behavior for all data types"""
@@ -177,7 +199,7 @@ def test_sort_backward(self):
177199
y = input_tensor * input_tensor
178200
else:
179201
y = input_tensor + 1
180-
sort_vals, sort_inds = sort(y, dim=dim)
202+
sort_vals, sort_inds = compat_sort(y, dim=dim)
181203
sort_vals.backward()
182204
if input_tensor.place.is_gpu_place():
183205
np.testing.assert_allclose(
@@ -194,7 +216,7 @@ def test_sort_backward(self):
194216
def test_edge_cases(self):
195217
"""Test edge cases and error handling"""
196218
x = paddle.to_tensor([])
197-
sort_res = sort(x, descending=True, stable=True)
219+
sort_res = compat_sort(x, descending=True, stable=True)
198220

199221
np.testing.assert_array_equal(
200222
sort_res.values.numpy(), np.array([], dtype=np.float32)
@@ -204,7 +226,7 @@ def test_edge_cases(self):
204226
)
205227

206228
x = paddle.to_tensor(1)
207-
sort_res = sort(input=x, stable=True)
229+
sort_res = compat_sort(input=x, stable=True)
208230

209231
np.testing.assert_array_equal(
210232
sort_res.values.numpy(), np.array(1, dtype=np.float32)
@@ -213,8 +235,8 @@ def test_edge_cases(self):
213235
sort_res.indices.numpy(), np.array(0, dtype=np.int64)
214236
)
215237

216-
msg_gt_1 = "paddle.sort() received unexpected keyword arguments 'input', 'dim'. \nDid you mean to use paddle.compat.sort() instead?"
217-
msg_gt_2 = "paddle.compat.sort() received unexpected keyword arguments 'x', 'axis'. \nDid you mean to use paddle.sort() instead?"
238+
msg_gt_1 = "paddle.sort() received unexpected keyword arguments 'dim', 'input'. \nDid you mean to use paddle.compat.sort() instead?"
239+
msg_gt_2 = "paddle.compat.sort() received unexpected keyword arguments 'axis', 'x'. \nDid you mean to use paddle.sort() instead?"
218240

219241
# invalid split sections
220242
with self.assertRaises(TypeError) as cm:
@@ -223,9 +245,44 @@ def test_edge_cases(self):
223245

224246
# invalid split axis
225247
with self.assertRaises(TypeError) as cm:
226-
sort(x=paddle.to_tensor([2, 1, 3]), axis=0)
248+
compat_sort(x=paddle.to_tensor([2, 1, 3]), axis=0)
227249
self.assertEqual(str(cm.exception), msg_gt_2)
228250

251+
def test_wrong_out_input(dim, out_input):
252+
with self.assertRaises(TypeError) as cm:
253+
compat_sort(paddle.to_tensor([1, 2]), out=out_input)
254+
255+
test_wrong_out_input(0, [0, paddle.to_tensor(0)])
256+
test_wrong_out_input(0, paddle.to_tensor(0))
257+
test_wrong_out_input(None, 0)
258+
test_wrong_out_input(None, (paddle.to_tensor(0),))
259+
260+
paddle.enable_static()
261+
with (
262+
self.assertRaises(RuntimeError) as cm,
263+
paddle.static.program_guard(paddle.static.Program()),
264+
):
265+
x = paddle.static.data(name='x', shape=[None, 6], dtype='float32')
266+
result0, result1 = compat_sort(
267+
paddle.arange(24),
268+
out=(
269+
paddle.zeros([24]),
270+
paddle.zeros([24], dtype=paddle.int64),
271+
),
272+
)
273+
274+
place = (
275+
paddle.CUDAPlace(0)
276+
if paddle.is_compiled_with_cuda()
277+
else paddle.CPUPlace()
278+
)
279+
paddle.static.Executor(place).run()
280+
self.assertEqual(
281+
str(cm.exception),
282+
"Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n",
283+
)
284+
paddle.disable_static()
285+
229286

230287
if __name__ == "__main__":
231288
unittest.main()

0 commit comments

Comments
 (0)