|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | | -from typing import TYPE_CHECKING |
| 17 | +from typing import TYPE_CHECKING, NamedTuple |
18 | 18 |
|
19 | 19 | import paddle |
20 | 20 | from paddle import _C_ops |
@@ -64,6 +64,7 @@ def split( |
64 | 64 | To use the original split of paddle, please consider `paddle.split` |
65 | 65 |
|
66 | 66 | Examples: |
| 67 | +
|
67 | 68 | .. code-block:: python |
68 | 69 |
|
69 | 70 | >>> import paddle |
@@ -211,3 +212,108 @@ def GetShapeOnDimInRange(shape, dim: int) -> int: |
211 | 212 | split_size_or_sections |
212 | 213 | ) |
213 | 214 | return tuple(_C_ops.split(tensor, split_size_or_sections, dim)) |
| 215 | + |
| 216 | + |
| 217 | +class SortRetType(NamedTuple): |
| 218 | + values: Tensor |
| 219 | + indices: Tensor |
| 220 | + |
| 221 | + |
| 222 | +def _check_out_status( |
| 223 | + out: Tensor | tuple[Tensor, Tensor] | list[Tensor], |
| 224 | + expect_multiple: bool = False, |
| 225 | +): |
| 226 | + if out is None: |
| 227 | + return |
| 228 | + if not in_dynamic_mode(): |
| 229 | + raise RuntimeError( |
| 230 | +            "Using `out` under the static graph (CINN) backend is currently not supported. Use the returned tensor tuple directly instead.\n" |
| 231 | + ) |
| 232 | + if expect_multiple: |
| 233 | + if not isinstance(out, (tuple, list)) or len(out) != 2: |
| 234 | + raise TypeError( |
| 235 | + f"Expected a list or tuple of two tensors, got {type(out)} instead." |
| 236 | + ) |
| 237 | + if not ( |
| 238 | + isinstance(out[0], paddle.Tensor) |
| 239 | + and isinstance(out[1], paddle.Tensor) |
| 240 | + ): |
| 241 | + raise TypeError( |
| 242 | + f"Expected Tensor type in the tuple/list, got ({type(out[0])}, {type(out[1])}) instead." |
| 243 | + ) |
| 244 | + else: |
| 245 | + if not isinstance(out, paddle.Tensor): |
| 246 | + raise TypeError(f"Expected a Tensor, got {type(out)} instead.") |
| 247 | + |
| 248 | + |
| 249 | +@ForbidKeywordsDecorator( |
| 250 | + illegal_keys={'x', 'axis'}, |
| 251 | + func_name="paddle.compat.sort", |
| 252 | + correct_name='paddle.sort', |
| 253 | +) |
| 254 | +def sort( |
| 255 | + input: Tensor, |
| 256 | + dim: int = -1, |
| 257 | + descending: bool = False, |
| 258 | + stable: bool = False, |
| 259 | + out=None, |
| 260 | +) -> SortRetType: |
| 261 | + """ |
| 262 | +
|
| 263 | +    Sorts the input along the given dimension, and returns the sorted values together with their indices. The sort is ascending by default; to sort in descending order, set :attr:`descending` to True. |
| 264 | +
|
| 265 | + Args: |
| 266 | +        input (Tensor): An input N-D Tensor with type float32, float64, int16, |
| 267 | +            int32, int64, uint8, float16, bfloat16. |
| 268 | +        dim (int, optional): The dimension to sort along. The effective range |
| 269 | +            is [-R, R), where R is Rank(input). When dim < 0, it works the same |
| 270 | +            way as dim + R. Default is -1. |
| 271 | +        descending (bool, optional): If set to True, the algorithm sorts in |
| 272 | +            descending order; otherwise it sorts in ascending order. |
| 273 | +            Default is False. |
| 274 | + stable (bool, optional): Whether to use stable sorting algorithm or not. |
| 275 | + When using stable sorting algorithm, the order of equivalent elements |
| 276 | + will be preserved. Default is False. |
| 277 | +        out (tuple, optional): an optional tuple/list of two Tensors to be |
| 278 | +            used as output buffers for (values, indices). Default is None. |
| 279 | +
|
| 280 | + Returns: |
| 281 | + SortRetType, a named tuple which contains `values` and `indices`, can be accessed through either indexing |
| 282 | + (e.g. `result[0]` for values and `result[1]` for indices), or by `result.values` & `result.indices` |
| 283 | +
|
| 284 | + Examples: |
| 285 | +
|
| 286 | + .. code-block:: python |
| 287 | +
|
| 288 | + >>> import paddle |
| 289 | +
|
| 290 | + >>> x = paddle.to_tensor([[5,8,9,5], |
| 291 | + ... [0,0,1,7], |
| 292 | + ... [6,9,2,4]], |
| 293 | + ... dtype='float32') |
| 294 | + >>> out1 = paddle.compat.sort(input=x, dim=-1) |
| 295 | + >>> out2 = paddle.compat.sort(x, 1, descending=True) |
| 296 | + >>> out1 |
| 297 | + SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True, |
| 298 | + [[5., 5., 8., 9.], |
| 299 | + [0., 0., 1., 7.], |
| 300 | + [2., 4., 6., 9.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True, |
| 301 | + [[0, 3, 1, 2], |
| 302 | + [0, 1, 2, 3], |
| 303 | + [2, 3, 0, 1]])) |
| 304 | + >>> out2 |
| 305 | + SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True, |
| 306 | + [[9., 8., 5., 5.], |
| 307 | + [7., 1., 0., 0.], |
| 308 | + [9., 6., 4., 2.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True, |
| 309 | + [[2, 1, 0, 3], |
| 310 | + [3, 2, 0, 1], |
| 311 | + [1, 0, 3, 2]])) |
| 312 | + """ |
| 313 | + _check_out_status(out, expect_multiple=True) |
| 314 | + outputs, indices = _C_ops.argsort(input, dim, descending, stable) |
| 315 | + if out is None: |
| 316 | + return SortRetType(values=outputs, indices=indices) |
| 317 | +    paddle.assign(outputs, out[0]) |
| 318 | +    paddle.assign(indices, out[1]) |
| 319 | +    return SortRetType(values=out[0], indices=out[1]) |
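
Below is a minimal usage sketch of the `paddle.compat.sort` API added in this diff, including the preallocated `out` buffers that the docstring describes but the in-docstring example does not exercise. The tensor values, shapes, and dtypes are illustrative assumptions, not part of the change.

```python
# Hypothetical usage sketch for the paddle.compat.sort added above.
# Values and shapes below are illustrative only.
import paddle

x = paddle.to_tensor([[3.0, 1.0, 2.0],
                      [9.0, 7.0, 8.0]])

# Plain call: returns a SortRetType named tuple.
result = paddle.compat.sort(x, dim=1)
print(result.values)   # sorted values along dim 1
print(result.indices)  # int64 positions of the sorted elements in x

# With preallocated buffers: `out` must be a (Tensor, Tensor) pair in dynamic
# mode, otherwise _check_out_status raises a TypeError/RuntimeError.
values_buf = paddle.empty([2, 3], dtype='float32')
indices_buf = paddle.empty([2, 3], dtype='int64')
paddle.compat.sort(x, dim=1, descending=True, out=(values_buf, indices_buf))
print(values_buf)   # now holds the descending-sorted values
print(indices_buf)  # and the corresponding indices
```

As set up by the `ForbidKeywordsDecorator` above, passing the legacy keyword names `x` or `axis` is rejected, with users pointed to `paddle.sort` instead.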