Skip to content
4 changes: 4 additions & 0 deletions paddle/phi/kernels/cpu/argsort_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,9 @@ PD_REGISTER_KERNEL(argsort_grad,
phi::ArgsortGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
uint8_t,
int16_t,
int,
int64_t) {}
14 changes: 12 additions & 2 deletions paddle/phi/kernels/cpu/argsort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,17 @@ void ArgsortKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(
argsort, CPU, ALL_LAYOUT, phi::ArgsortKernel, float, double, int, int64_t) {
PD_REGISTER_KERNEL(argsort,
CPU,
ALL_LAYOUT,
phi::ArgsortKernel,
float,
double,
int,
int64_t,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/argsort_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,7 @@ PD_REGISTER_KERNEL(argsort_grad,
double,
int,
int64_t,
uint8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/argsort_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ PD_REGISTER_KERNEL(argsort,
double,
int,
int64_t,
uint8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

from .tensor.compat import (
sort,
split,
)

__all__ = [
'split',
'sort',
]
107 changes: 106 additions & 1 deletion python/paddle/tensor/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple

import paddle
from paddle import _C_ops
Expand Down Expand Up @@ -64,6 +64,7 @@ def split(
To use the original split of paddle, please consider `paddle.split`

Examples:

.. code-block:: python

>>> import paddle
Expand Down Expand Up @@ -211,3 +212,107 @@ def GetShapeOnDimInRange(shape, dim: int) -> int:
split_size_or_sections
)
return tuple(_C_ops.split(tensor, split_size_or_sections, dim))


class SortRetType(NamedTuple):
values: Tensor
indices: Tensor


def _check_out_status(
out: Tensor | tuple[Tensor, Tensor] | list[Tensor],
expect_multiple: bool = False,
):
if out is None:
return
if not in_dynamic_mode():
raise RuntimeError(
"Using `out` static graph CINN backend is currently not supported. Directly return the tensor tuple instead.\n"
)
if expect_multiple:
if not isinstance(out, (tuple, list)) or len(out) != 2:
raise TypeError(
f"Expected a list or tuple of two tensors, got {type(out)} instead."
)
if not (
isinstance(out[0], paddle.Tensor)
and isinstance(out[1], paddle.Tensor)
):
raise TypeError(
f"Expected Tensor type in the tuple/list, got ({type(out[0])}, {type(out[1])}) instead."
)
else:
if not isinstance(out, paddle.Tensor):
raise TypeError(f"Expected a Tensor, got {type(out)} instead.")


@ForbidKeywordsDecorator(
illegal_keys={'x', 'axis'},
func_name="paddle.compat.sort",
correct_name='paddle.sort',
)
def sort(
input: Tensor,
dim: int = -1,
descending: bool = False,
stable: bool = False,
out=None,
) -> SortRetType:
"""

Sorts the input along the given dimension, and returns the sorted output and indices tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.

Args:
input (Tensor): An input N-D Tensor with type float32, float64, int16,
int32, int64, uint8, float16, bfloat16
dim (int, optional): Dimension to compute indices along. The effective range
is [-R, R), where R is Rank(x). when dim<0, it works the same way
as dim+R. Default is -1.
descending (bool, optional) : Descending is a flag, if set to true,
algorithm will sort by descending order, else sort by
ascending order. Default is false.
stable (bool, optional): Whether to use stable sorting algorithm or not.
When using stable sorting algorithm, the order of equivalent elements
will be preserved. Default is False.
out (tuple, optional) : the output tuple/list of (Tensor, Tensor) that
can be optionally given to be used as output buffers

Returns:
SortRetType, a named tuple which contains `values` and `indices`, can be accessed through either indexing
(e.g. `result[0]` for values and `result[1]` for indices), or by `result.values` & `result.indices`

Examples:

.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([[5,8,9,5],
... [0,0,1,7],
... [6,9,2,4]],
... dtype='float32')
>>> out1 = paddle.compat.sort(input=x, dim=-1)
>>> out2 = paddle.compat.sort(x, 1, descending=True)
>>> out1
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[5., 5., 8., 9.],
[0., 0., 1., 7.],
[2., 4., 6., 9.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[0, 3, 1, 2],
[0, 1, 2, 3],
[2, 3, 0, 1]]))
>>> out2
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[9., 8., 5., 5.],
[7., 1., 0., 0.],
[9., 6., 4., 2.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2, 1, 0, 3],
[3, 2, 0, 1],
[1, 0, 3, 2]]))
"""
_check_out_status(out, expect_multiple=True)
outputs, indices = _C_ops.argsort(input, dim, descending, stable)
if out is None:
return SortRetType(values=outputs, indices=indices)
paddle.assign(outputs, out[0])
paddle.assign(indices, out[1])
7 changes: 7 additions & 0 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from paddle import Tensor
from paddle._typing import DTypeLike

from paddle.utils.decorator_utils import ForbidKeywordsDecorator

# from ..base.layers import has_inf #DEFINE_ALIAS
# from ..base.layers import has_nan #DEFINE_ALIAS

Expand Down Expand Up @@ -623,6 +625,11 @@ def _restrict_nonzero(condition: Tensor, total_true_num: int) -> Tensor:
return _C_ops.restrict_nonzero(condition, total_true_num)


@ForbidKeywordsDecorator(
illegal_keys={'input', 'dim'},
func_name='paddle.sort',
correct_name='paddle.compat.sort',
)
def sort(
x: Tensor,
axis: int = -1,
Expand Down
Loading