Skip to content

Commit 39e0b4f

Browse files
authored
[API-Compat] Add paddle.compat.sort and upgrade PHI kernel for argsort (type expansion) (#74558)
* [API-Compat] Added paddle.compat.sort and tested * [API-Compat] Updated EN docs * [API-Compat] Fixed EN doc and updated decorator * [API-Compat] Fixed EN Doc * [API-Compat] Updated forbid-keyword decorator * [API-Compat] Resolved merge conflicts. * [API-Compat] Fixed Doc test * [API-Compat] Fixed compat import * [API-Compat] Resolved merge conflicts * [API-Compat] Resolved failed pre-commit
1 parent 97f063f commit 39e0b4f

File tree

8 files changed

+423
-3
lines changed

8 files changed

+423
-3
lines changed

paddle/phi/kernels/cpu/argsort_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,9 @@ PD_REGISTER_KERNEL(argsort_grad,
136136
phi::ArgsortGradKernel,
137137
float,
138138
double,
139+
phi::dtype::float16,
140+
phi::dtype::bfloat16,
141+
uint8_t,
142+
int16_t,
139143
int,
140144
int64_t) {}

paddle/phi/kernels/cpu/argsort_kernel.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,17 @@ void ArgsortKernel(const Context& dev_ctx,
181181

182182
} // namespace phi
183183

184-
PD_REGISTER_KERNEL(
185-
argsort, CPU, ALL_LAYOUT, phi::ArgsortKernel, float, double, int, int64_t) {
184+
PD_REGISTER_KERNEL(argsort,
185+
CPU,
186+
ALL_LAYOUT,
187+
phi::ArgsortKernel,
188+
float,
189+
double,
190+
int,
191+
int64_t,
192+
int16_t,
193+
uint8_t,
194+
phi::dtype::float16,
195+
phi::dtype::bfloat16) {
186196
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
187197
}

paddle/phi/kernels/gpu/argsort_grad_kernel.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,7 @@ PD_REGISTER_KERNEL(argsort_grad,
232232
double,
233233
int,
234234
int64_t,
235+
uint8_t,
236+
int16_t,
235237
phi::dtype::float16,
236238
phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/argsort_kernel.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ PD_REGISTER_KERNEL(argsort,
486486
double,
487487
int,
488488
int64_t,
489+
uint8_t,
490+
int16_t,
489491
phi::dtype::float16,
490492
phi::dtype::bfloat16) {
491493
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);

python/paddle/compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414

1515
from .tensor.compat import (
16+
sort,
1617
split,
1718
)
1819

1920
__all__ = [
2021
'split',
22+
'sort',
2123
]

python/paddle/tensor/compat.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, NamedTuple
1818

1919
import paddle
2020
from paddle import _C_ops
@@ -64,6 +64,7 @@ def split(
6464
To use the original split of paddle, please consider `paddle.split`
6565
6666
Examples:
67+
6768
.. code-block:: python
6869
6970
>>> import paddle
@@ -211,3 +212,107 @@ def GetShapeOnDimInRange(shape, dim: int) -> int:
211212
split_size_or_sections
212213
)
213214
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))
215+
216+
217+
class SortRetType(NamedTuple):
218+
values: Tensor
219+
indices: Tensor
220+
221+
222+
def _check_out_status(
223+
out: Tensor | tuple[Tensor, Tensor] | list[Tensor],
224+
expect_multiple: bool = False,
225+
):
226+
if out is None:
227+
return
228+
if not in_dynamic_mode():
229+
raise RuntimeError(
230+
"Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n"
231+
)
232+
if expect_multiple:
233+
if not isinstance(out, (tuple, list)) or len(out) != 2:
234+
raise TypeError(
235+
f"Expected a list or tuple of two tensors, got {type(out)} instead."
236+
)
237+
if not (
238+
isinstance(out[0], paddle.Tensor)
239+
and isinstance(out[1], paddle.Tensor)
240+
):
241+
raise TypeError(
242+
f"Expected Tensor type in the tuple/list, got ({type(out[0])}, {type(out[1])}) instead."
243+
)
244+
else:
245+
if not isinstance(out, paddle.Tensor):
246+
raise TypeError(f"Expected a Tensor, got {type(out)} instead.")
247+
248+
249+
@ForbidKeywordsDecorator(
250+
illegal_keys={'x', 'axis'},
251+
func_name="paddle.compat.sort",
252+
correct_name='paddle.sort',
253+
)
254+
def sort(
255+
input: Tensor,
256+
dim: int = -1,
257+
descending: bool = False,
258+
stable: bool = False,
259+
out=None,
260+
) -> SortRetType:
261+
"""
262+
263+
Sorts the input along the given dimension, and returns the sorted output and indices tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
264+
265+
Args:
266+
input (Tensor): An input N-D Tensor with type float32, float64, int16,
267+
int32, int64, uint8, float16, bfloat16
268+
dim (int, optional): Dimension to compute indices along. The effective range
269+
is [-R, R), where R is Rank(x). when dim<0, it works the same way
270+
as dim+R. Default is -1.
271+
descending (bool, optional) : Descending is a flag, if set to true,
272+
algorithm will sort by descending order, else sort by
273+
ascending order. Default is false.
274+
stable (bool, optional): Whether to use stable sorting algorithm or not.
275+
When using stable sorting algorithm, the order of equivalent elements
276+
will be preserved. Default is False.
277+
out (tuple, optional) : the output tuple/list of (Tensor, Tensor) that
278+
can be optionally given to be used as output buffers
279+
280+
Returns:
281+
SortRetType, a named tuple which contains `values` and `indices`, can be accessed through either indexing
282+
(e.g. `result[0]` for values and `result[1]` for indices), or by `result.values` & `result.indices`
283+
284+
Examples:
285+
286+
.. code-block:: python
287+
288+
>>> import paddle
289+
290+
>>> x = paddle.to_tensor([[5,8,9,5],
291+
... [0,0,1,7],
292+
... [6,9,2,4]],
293+
... dtype='float32')
294+
>>> out1 = paddle.compat.sort(input=x, dim=-1)
295+
>>> out2 = paddle.compat.sort(x, 1, descending=True)
296+
>>> out1
297+
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
298+
[[5., 5., 8., 9.],
299+
[0., 0., 1., 7.],
300+
[2., 4., 6., 9.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
301+
[[0, 3, 1, 2],
302+
[0, 1, 2, 3],
303+
[2, 3, 0, 1]]))
304+
>>> out2
305+
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
306+
[[9., 8., 5., 5.],
307+
[7., 1., 0., 0.],
308+
[9., 6., 4., 2.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
309+
[[2, 1, 0, 3],
310+
[3, 2, 0, 1],
311+
[1, 0, 3, 2]]))
312+
"""
313+
_check_out_status(out, expect_multiple=True)
314+
outputs, indices = _C_ops.argsort(input, dim, descending, stable)
315+
if out is None:
316+
return SortRetType(values=outputs, indices=indices)
317+
paddle.assign(outputs, out[0])
318+
paddle.assign(indices, out[1])

python/paddle/tensor/search.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from paddle import Tensor
4040
from paddle._typing import DTypeLike
4141

42+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
43+
4244
# from ..base.layers import has_inf #DEFINE_ALIAS
4345
# from ..base.layers import has_nan #DEFINE_ALIAS
4446

@@ -623,6 +625,11 @@ def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor:
623625
return _C_ops.restrict_nonzero(condition, total_true_num)
624626

625627

628+
@ForbidKeywordsDecorator(
629+
illegal_keys={'input', 'dim'},
630+
func_name='paddle.sort',
631+
correct_name='paddle.compat.sort',
632+
)
626633
def sort(
627634
x: Tensor,
628635
axis: int = -1,

0 commit comments

Comments
 (0)