diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu
index b8d9df64c23efb..695044c095735e 100644
--- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu
@@ -222,4 +222,5 @@ PD_REGISTER_KERNEL(argsort_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu
index 5102594f98d1e0..3b502a567a499f 100644
--- a/paddle/phi/kernels/gpu/argsort_kernel.cu
+++ b/paddle/phi/kernels/gpu/argsort_kernel.cu
@@ -61,6 +61,11 @@ namespace cub {
 template <>
 struct NumericTraits<phi::dtype::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
+
+template <>
+struct NumericTraits<phi::dtype::bfloat16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
+};
 }  // namespace cub
 #endif
 
@@ -328,6 +333,7 @@ PD_REGISTER_KERNEL(argsort,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 6172c0247554d6..5f62f2cc539d36 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -109,6 +109,7 @@ def argsort(x, axis=-1, descending=False, name=None):
                 'int32',
                 'int64',
                 'uint8',
+                'uint16',
             ],
             'argsort',
         )
diff --git a/test/legacy_test/test_argsort_op.py b/test/legacy_test/test_argsort_op.py
index ec6db2f6651e99..3a5dff216af4ab 100644
--- a/test/legacy_test/test_argsort_op.py
+++ b/test/legacy_test/test_argsort_op.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
@@ -513,5 +514,62 @@ def test_fp16(self):
                 out = exe.run(feed={'x': x_np}, fetch_list=[out])
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestArgsortBF16OP(OpTest):
+    """Check argsort forward/backward for bfloat16 inputs on a CUDA device.
+
+    Inputs are generated as float32 and passed through
+    ``convert_float_to_uint16``; ``self.dtype`` is ``np.uint16`` to match.
+    """
+    def setUp(self):
+        self.init()
+        self.op_type = "argsort"
+        self.python_api = paddle.argsort
+        self.public_python_api = paddle.argsort
+        self.python_out_sig = [
+            "Out"
+        ]  # python out sig is customized output signature.
+        self.dtype = np.uint16
+        self.descending = False
+        self.attrs = {"axis": self.axis, "descending": self.descending}
+        # Reference results are computed by NumPy on the float32 data.
+        self.x = np.random.rand(*self.input_shape).astype(np.float32)
+        self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis).astype(
+            np.float32
+        )
+        self.indices = np.argsort(
+            self.x, kind='heapsort', axis=self.axis
+        ).astype(np.float32)
+        self.inputs = {'X': convert_float_to_uint16(self.x)}
+        # NOTE(review): the GPU kernel declares output 1 as INT64, but the
+        # 'Indices' expectation below is bf16-encoded -- confirm intended.
+        self.outputs = {
+            'Out': convert_float_to_uint16(self.sorted_x),
+            "Indices": convert_float_to_uint16(self.indices),
+        }
+
+    def init(self):
+        self.input_shape = [
+            1000,
+        ]
+        self.axis = 0
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
+
+
 if __name__ == "__main__":
     unittest.main()