diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 9111fe8eda5af1..a138b8d52324f1 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -579,6 +579,7 @@ kthvalue, masked_select, mode, + msort, nonzero, searchsorted, sort, @@ -879,6 +880,7 @@ 'summary', 'flops', 'sort', + 'msort', 'searchsorted', 'bucketize', 'split', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 75d2882a04006f..3f0495ac94fffc 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -459,6 +459,7 @@ kthvalue, masked_select, mode, + msort, nonzero, searchsorted, sort, @@ -726,6 +727,7 @@ 'index_select', 'nonzero', 'sort', + 'msort', 'index_sample', 'mean', 'std', diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 3837d7595f8cc7..6b91b36f40fa3a 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -676,6 +676,44 @@ def sort( return out +def msort(input: Tensor) -> Tensor: + """ + + Sorts the input along the given axis = 0, and returns the sorted output tensor. The sort algorithm is ascending. + + Args: + input (Tensor): An input N-D Tensor with type float32, float64, int16, + int32, int64, uint8. + + Returns: + Tensor, sorted tensor(with the same shape and data type as ``input``). + + Examples: + + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([[[5,8,9,5], + ... [0,0,1,7], + ... [6,9,2,4]], + ... [[5,2,4,2], + ... [4,7,7,9], + ... [1,7,0,6]]], + ... dtype='float32') + >>> out1 = paddle.msort(input=x) + >>> print(out1.numpy()) + [[[5. 2. 4. 2.] + [0. 0. 1. 7.] + [1. 7. 0. 4.]] + [[5. 8. 9. 5.] + [4. 7. 7. 9.] + [6. 9. 2. 6.]]] + """ + + return sort(input, axis=0) + + def mode( x: Tensor, axis: int = -1, keepdim: bool = False, name: str | None = None ) -> tuple[Tensor, Tensor]: diff --git a/test/legacy_test/test_msort_op.py b/test/legacy_test/test_msort_op.py new file mode 100644 index 00000000000000..aac9e4764e2702 --- /dev/null +++ b/test/legacy_test/test_msort_op.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base +from paddle.base import core + + +class TestMsortOnCPU(unittest.TestCase): + def setUp(self): + self.place = core.CPUPlace() + + def test_api_0(self): + with base.program_guard(base.Program()): + input = paddle.static.data( + name="input", shape=[2, 3, 4], dtype="float32" + ) + output = paddle.msort(input=input) + exe = base.Executor(self.place) + data = np.array( + [ + [[5, 8, 9, 5], [0, 0, 1, 7], [6, 9, 2, 4]], + [[5, 2, 4, 2], [4, 7, 7, 9], [1, 7, 0, 6]], + ], + dtype='float32', + ) + (result,) = exe.run(feed={'input': data}, fetch_list=[output]) + np_result = np.sort(result, axis=0) + self.assertEqual((result == np_result).all(), True) + + +class TestMsortOnGPU(TestMsortOnCPU): + def init_place(self): + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + +class TestMsortDygraph(unittest.TestCase): + def setUp(self): + self.input_data = np.random.rand(10, 10) + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + def test_api_0(self): + paddle.disable_static(self.place) + var_x = paddle.to_tensor(self.input_data) + out = paddle.msort(input=var_x) + self.assertEqual( + (np.sort(self.input_data, axis=0) == out.numpy()).all(), True + ) + paddle.enable_static() + + +if __name__ == '__main__': + unittest.main()