No.55 complete argsort FP16 test, add argsort BF16 support and test #51823

Closed. superwinner1 wants to merge 52 commits into PaddlePaddle:develop from superwinner1:argsort.
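For context, a minimal sketch of what the change enables once merged (not code from this PR; it assumes a CUDA build, since the new bfloat16 kernels are registered for the GPU backend only):

import paddle

# Hypothetical usage: cast a float32 tensor to bfloat16 and sort it on GPU.
# Before this PR, argsort/sort had no bfloat16 GPU kernel registered.
x = paddle.rand([4, 8]).astype('bfloat16')

indices = paddle.argsort(x, axis=-1)   # the Indices output is forced to int64
values = paddle.sort(x, axis=-1)       # sorted values stay bfloat16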
Commits (52):
08b9625  complete argsort FP16 test, add argsort BF16 support and test (superwinner1, Mar 19, 2023)
6d4fd5f  Codestyle (superwinner1, Mar 19, 2023)
fa199b7  fix (superwinner1, Mar 20, 2023)
163b6f3  fix (superwinner1, Mar 20, 2023)
8c8a3d2  fix (superwinner1, Mar 20, 2023)
da3cf24  fix (superwinner1, Mar 20, 2023)
0e5912c  fix (superwinner1, Mar 21, 2023)
eaf3d20  fix (superwinner1, Mar 21, 2023)
7b622d6  fix (superwinner1, Mar 21, 2023)
9768f90  fix (superwinner1, Mar 21, 2023)
29aee20  Merge branch 'PaddlePaddle:develop' into argsort (superwinner1, Mar 21, 2023)
db49513  test (superwinner1, Mar 21, 2023)
afbc27f  Merge remote-tracking branch 'origin/argsort' into argsort (superwinner1, Mar 21, 2023)
06793cb  fix (superwinner1, Mar 22, 2023)
dbc92d2  fix (superwinner1, Mar 22, 2023)
d49a104  fix (superwinner1, Mar 22, 2023)
8d863f0  fix (superwinner1, Mar 23, 2023)
c915326  fix (superwinner1, Mar 24, 2023)
d2c35de  'fix' (superwinner1, Apr 7, 2023)
3782c85  'fix' (superwinner1, Apr 8, 2023)
44e8651  'fix' (superwinner1, Apr 8, 2023)
650b4f3  'fix' (superwinner1, Apr 16, 2023)
b317f17  'fix' (superwinner1, Apr 17, 2023)
5271637  'fix' (superwinner1, Apr 17, 2023)
0539161  'fix' (superwinner1, Apr 18, 2023)
8a0e360  'fix' (superwinner1, Apr 27, 2023)
d06273f  'fix' (superwinner1, Apr 28, 2023)
f159c51  Merge branch 'PaddlePaddle:develop' into argsort (superwinner1, May 5, 2023)
ea230a3  'fix' (superwinner1, May 6, 2023)
28434aa  Merge remote-tracking branch 'origin/argsort' into argsort (superwinner1, May 6, 2023)
c642997  fix (superwinner1, May 7, 2023)
16935c2  'fix' (superwinner1, May 7, 2023)
27acc85  fix (superwinner1, May 8, 2023)
3b5fd84  'fix' (superwinner1, May 12, 2023)
87f67bb  'fix' (superwinner1, May 12, 2023)
be56eca  'fix' (superwinner1, May 13, 2023)
ee513cf  'fix' (superwinner1, May 14, 2023)
d002159  'fix' (superwinner1, May 15, 2023)
981e437  'fix' (superwinner1, May 17, 2023)
ad2b843  'fix' (superwinner1, May 18, 2023)
4f98dce  'fix' (superwinner1, May 18, 2023)
3071b63  'fix' (superwinner1, May 21, 2023)
ed632c6  'fix' (superwinner1, May 21, 2023)
ebad139  'fix' (superwinner1, May 22, 2023)
8355b29  'fix' (superwinner1, May 22, 2023)
d325a03  'fix' (superwinner1, May 23, 2023)
0e3d24e  'fix' (superwinner1, May 23, 2023)
d4a8119  'fix' (superwinner1, May 23, 2023)
f1306eb  'fix' (superwinner1, May 24, 2023)
7e42c6b  'fix' (superwinner1, May 25, 2023)
6893f4f  'fix' (superwinner1, May 25, 2023)
0fec5d6  'fix' (superwinner1, May 29, 2023)
paddle/phi/kernels/gpu/argsort_grad_kernel.cu (3 changes: 2 additions & 1 deletion)
@@ -222,4 +222,5 @@ PD_REGISTER_KERNEL(argsort_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
paddle/phi/kernels/gpu/argsort_kernel.cu (12 changes: 11 additions & 1 deletion)
@@ -40,6 +40,10 @@ namespace detail {
 template <>
 struct radix_key_codec_base<phi::dtype::float16>
     : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
+
+template <>
+struct radix_key_codec_base<phi::dtype::bfloat16>
+    : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
 } // namespace detail
 } // namespace rocprim
 #else
@@ -48,6 +52,11 @@ namespace cub {
 template <>
 struct NumericTraits<phi::dtype::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
+
+template <>
+struct NumericTraits<phi::dtype::bfloat16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
+};
 } // namespace cub
 #endif

@@ -519,6 +528,7 @@ PD_REGISTER_KERNEL(argsort,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
 }
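The radix_key_codec_base and NumericTraits specializations above let rocPRIM/CUB radix-sort bfloat16 keys through their raw 16-bit patterns, mirroring the existing float16 path. A rough NumPy illustration (not Paddle code) of why a 16-bit unsigned key preserves the ordering, assuming the usual sign-flip transform radix sorts apply to signed floating-point keys:

import numpy as np

# Values chosen to be exactly representable in bfloat16 (their lower 16 float32 bits are zero).
vals = np.array([3.25, -0.375, 100.0, 0.0, -2.5, 1.0], dtype=np.float32)

# bfloat16 keeps float32's sign and exponent: the upper 16 bits of the
# float32 encoding are the bfloat16 bit pattern (truncation shown here).
bf16_bits = (vals.view(np.uint32) >> 16).astype(np.uint16)

# Standard radix-sort key transform for signed floats: flip every bit of
# negative values, set the sign bit of non-negative ones. After that,
# plain unsigned ordering of the keys matches numeric ordering.
keys = np.where(bf16_bits & 0x8000, ~bf16_bits, bf16_bits | 0x8000)

assert list(np.argsort(keys)) == list(np.argsort(vals))  # same permutation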
python/paddle/fluid/tests/unittests/test_argsort_op.py (48 changes: 48 additions & 0 deletions)
@@ -15,6 +15,7 @@
 import unittest

 import numpy as np
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
@@ -392,6 +393,7 @@ def setUp(self):
         self.data = np.random.rand(*self.input_shape)

     def test_api(self):
+        paddle.enable_static()
         with fluid.program_guard(fluid.Program()):
             input = paddle.static.data(
                 name="input", shape=self.input_shape, dtype="float64"
@@ -513,5 +515,51 @@ def test_fp16(self):
                 out = exe.run(feed={'x': x_np}, fetch_list=[out])


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestArgsortBF16OP(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "argsort"
+        self.python_api = paddle.argsort
+        self.public_python_api = paddle.argsort
+        self.dtype = np.uint16
+        self.descending = False
+        self.attrs = {"axis": self.axis, "descending": self.descending}
+        self.x = np.random.rand(*self.input_shape).astype(np.float32)
+        self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis).astype(
+            np.float32
+        )
+        self.indices = np.argsort(
+            self.x, kind='heapsort', axis=self.axis
+        ).astype(np.float32)
+        self.inputs = {'X': convert_float_to_uint16(self.x)}
+        self.outputs = {
+            'Out': convert_float_to_uint16(self.sorted_x),
+            "Indices": convert_float_to_uint16(self.indices),
+        }
+
+    def init(self):
+        self.input_shape = [
+            1000,
+        ]
+        self.axis = 0
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
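For readers unfamiliar with the OpTest helpers: convert_float_to_uint16 packs float32 data into bfloat16 bit patterns stored as np.uint16, which is why self.dtype is np.uint16 above. A truncation-only stand-in (hypothetical helper names; the real utility also rounds) that shows the precision the BF16 test operates at:

import numpy as np

def float_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 value: sign, 8 exponent bits,
    # and the top 7 mantissa bits.
    x = np.ascontiguousarray(np.asarray(x, dtype=np.float32))
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float(bits):
    # Re-expand to float32 by zero-filling the low 16 bits.
    return (np.ascontiguousarray(bits).astype(np.uint32) << 16).view(np.float32)

data = np.random.rand(1000).astype(np.float32)
roundtrip = bf16_bits_to_float(float_to_bf16_bits(data))
print(np.abs(roundtrip - data).max())   # roughly 4e-3 at most: ~8 significant bits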
python/paddle/tensor/search.py (1 change: 1 addition & 0 deletions)
@@ -108,6 +108,7 @@ def argsort(x, axis=-1, descending=False, name=None):
                 'int16',
                 'int32',
                 'int64',
+                'uint16',
                 'uint8',
             ],
             'argsort',
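The single-line change above is what makes the static-graph path accept bfloat16: a bfloat16 variable reports its dtype as 'uint16' inside check_variable_and_dtype, so without this entry paddle.argsort would raise a dtype error before ever reaching the new kernel. A minimal sketch of that path (assumptions: a CUDA place with bfloat16 support, and Paddle's convention of feeding bfloat16 data as np.uint16 bit patterns):

import numpy as np
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[16], dtype='bfloat16')
    out = paddle.argsort(x)   # passes the dtype check only with 'uint16' allowed

# bfloat16 has no numpy dtype, so the feed carries the raw uint16 bit patterns.
x_np = (np.random.rand(16).astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
(idx,) = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])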