Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds descending order sort to argsort #144

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ how fast this is relative to `std::sort`.

## Sort an array of built-in integers and floats
```cpp
void x86simdsort::qsort(T* arr, size_t size, bool hasnan);
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan);
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan);
void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending);
void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending);
void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
int32_t, double, uint64_t, int64_t]`
Expand All @@ -53,7 +53,7 @@ data types.

## Arg sort routines on arrays
```cpp
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan);
std::vector<size_t> arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending);
std::vector<size_t> arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan);
```
Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t,
Expand Down
17 changes: 17 additions & 0 deletions benchmarks/bench-argsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ static void simdargsort(benchmark::State &state, Args &&...args)
}
}

/*
 * Benchmark for descending-order argsort: times x86simdsort::argsort called
 * with hasnan = false and descending = true.
 *
 * state : benchmark::State object that drives the timing loop
 * args  : (array size, array type name), in that order; both are handed to
 *         get_array<T> to build the input array
 */
template <typename T, class... Args>
static void simd_revargsort(benchmark::State &state, Args &&...args)
{
    // get args
    // Args&& is a forwarding reference: std::forward leaves lvalue arguments
    // intact (they are copied into the tuple), whereas an unconditional
    // std::move would silently move from the caller's objects.
    auto args_tuple = std::make_tuple(std::forward<Args>(args)...);
    size_t arrsize = std::get<0>(args_tuple);
    std::string arrtype = std::get<1>(args_tuple);
    // set up array
    std::vector<T> arr = get_array<T>(arrtype, arrsize);
    std::vector<size_t> inx;
    // benchmark: the same input array is reused on every iteration
    for (auto _ : state) {
        inx = x86simdsort::argsort(arr.data(), arrsize, false, true);
    }
}

template <typename T, class... Args>
static void simd_ordern_argsort(benchmark::State &state, Args &&...args)
{
Expand All @@ -68,6 +84,7 @@ static void simd_ordern_argsort(benchmark::State &state, Args &&...args)

#define BENCH_BOTH(type) \
BENCH_SORT(simdargsort, type) \
BENCH_SORT(simd_revargsort, type) \
BENCH_SORT(simd_ordern_argsort, type) \
BENCH_SORT(scalarargsort, type)

Expand Down
5 changes: 3 additions & 2 deletions lib/x86simdsort-avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \
} \
template <> \
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
std::vector<size_t> argsort( \
type *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return avx2_argsort(arr, arrsize, hasnan); \
return avx2_argsort(arr, arrsize, hasnan, descending); \
} \
template <> \
std::vector<size_t> argselect( \
Expand Down
18 changes: 12 additions & 6 deletions lib/x86simdsort-internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ namespace avx512 {
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
Expand Down Expand Up @@ -62,8 +64,10 @@ namespace avx2 {
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
Expand Down Expand Up @@ -94,8 +98,10 @@ namespace scalar {
bool descending = false);
// argsort
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan = false);
XSS_HIDE_SYMBOL std::vector<size_t> argsort(T *arr,
size_t arrsize,
bool hasnan = false,
bool descending = false);
// argselect
template <typename T>
XSS_HIDE_SYMBOL std::vector<size_t>
Expand Down
15 changes: 12 additions & 3 deletions lib/x86simdsort-scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,21 @@ namespace scalar {
xss::utils::get_cmp_func<T>(hasnan, reversed));
}
template <typename T>
std::vector<size_t> argsort(T *arr, size_t arrsize, bool hasnan)
std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan, bool reversed)
{
UNUSED(hasnan);
std::vector<size_t> arg(arrsize);
std::iota(arg.begin(), arg.end(), 0);
std::sort(arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
if (reversed) {
std::sort(arg.begin(),
arg.end(),
compare_arg<T, std::greater<T>>(arr));
}
else {
std::sort(
arg.begin(), arg.end(), compare_arg<T, std::less<T>>(arr));
}
return arg;
}
template <typename T>
Expand All @@ -93,7 +102,7 @@ namespace scalar {
template <typename T1, typename T2>
void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan)
{
std::vector<size_t> arg = argsort(key, arrsize, hasnan);
std::vector<size_t> arg = argsort(key, arrsize, hasnan, false);
utils::apply_permutation_in_place(key, arg);
utils::apply_permutation_in_place(val, arg);
}
Expand Down
5 changes: 3 additions & 2 deletions lib/x86simdsort-skx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \
} \
template <> \
std::vector<size_t> argsort(type *arr, size_t arrsize, bool hasnan) \
std::vector<size_t> argsort( \
type *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return avx512_argsort(arr, arrsize, hasnan); \
return avx512_argsort(arr, arrsize, hasnan, descending); \
} \
template <> \
std::vector<size_t> argselect( \
Expand Down
8 changes: 5 additions & 3 deletions lib/x86simdsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ namespace x86simdsort {
}

#define DECLARE_INTERNAL_argsort(TYPE) \
static std::vector<size_t> (*internal_argsort##TYPE)(TYPE *, size_t, bool) \
static std::vector<size_t> (*internal_argsort##TYPE)( \
TYPE *, size_t, bool, bool) \
= NULL; \
template <> \
std::vector<size_t> argsort(TYPE *arr, size_t arrsize, bool hasnan) \
std::vector<size_t> argsort( \
TYPE *arr, size_t arrsize, bool hasnan, bool descending) \
{ \
return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \
return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \
}

#define DECLARE_INTERNAL_argselect(TYPE) \
Expand Down
2 changes: 1 addition & 1 deletion lib/x86simdsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr,
// argsort
template <typename T>
XSS_EXPORT_SYMBOL std::vector<size_t>
argsort(T *arr, size_t arrsize, bool hasnan = false);
argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false);

// argselect
template <typename T>
Expand Down
20 changes: 10 additions & 10 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Equivalent to `qsort` in
`std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort).

```cpp
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false);
void avx512_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
void avx2_qsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
```
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
Expand All @@ -30,8 +30,8 @@ Equivalent to `std::nth_element` in


```cpp
void avx512_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
void avx2_qselect<T>(T* arr, size_t arrsize, bool hasnan = false);
void avx512_qselect<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
void avx2_qselect<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
```
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
Expand All @@ -46,8 +46,8 @@ Equivalent to `std::partial_sort` in


```cpp
void avx512_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
void avx2_partial_qsort<T>(T* arr, size_t arrsize, bool hasnan = false)
void avx512_partial_qsort<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
void avx2_partial_qsort<T>(T* arr, size_t k, size_t arrsize, bool hasnan = false, bool descending = false);
```
Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`,
`float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support
Expand All @@ -61,8 +61,8 @@ Equivalent to `np.argsort` in
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html).

```cpp
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize, bool hasnan = false, bool descending = false);
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false);
```
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
`double`.
Expand All @@ -74,8 +74,8 @@ Equivalent to `np.argpartition` in
[NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html).

```cpp
std::vector<size_t> arg = avx512_argsort<T>(T* arr, size_t arrsize);
void avx512_argsort<T>(T* arr, size_t *arg, size_t arrsize);
std::vector<size_t> arg = avx512_argselect<T>(T* arr, size_t k, size_t arrsize);
void avx512_argselect<T>(T* arr, size_t *arg, size_t k, size_t arrsize);
```
Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and
`double`.
Expand Down
36 changes: 26 additions & 10 deletions src/xss-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,11 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,

/* argsort methods for 32-bit and 64-bit dtypes */
template <typename T>
X86_SIMD_SORT_INLINE void
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
X86_SIMD_SORT_INLINE void avx512_argsort(T *arr,
arrsize_t *arg,
arrsize_t arrsize,
bool hasnan = false,
bool descending = false)
{
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
Expand All @@ -558,29 +561,37 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
if constexpr (std::is_floating_point_v<T>) {
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
std_argsort_withnan(arr, arg, 0, arrsize);

if (descending) { std::reverse(arg, arg + arrsize); }

return;
}
}
UNUSED(hasnan);
argsort_64bit_<vectype, argtype>(
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));

if (descending) { std::reverse(arg, arg + arrsize); }
}
}

template <typename T>
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
X86_SIMD_SORT_INLINE std::vector<arrsize_t> avx512_argsort(
T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
{
std::vector<arrsize_t> indices(arrsize);
std::iota(indices.begin(), indices.end(), 0);
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan);
avx512_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
return indices;
}

/* argsort methods for 32-bit and 64-bit dtypes */
template <typename T>
X86_SIMD_SORT_INLINE void
avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
X86_SIMD_SORT_INLINE void avx2_argsort(T *arr,
arrsize_t *arg,
arrsize_t arrsize,
bool hasnan = false,
bool descending = false)
{
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
avx2_half_vector<T>,
Expand All @@ -594,22 +605,27 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
if constexpr (std::is_floating_point_v<T>) {
if ((hasnan) && (array_has_nan<vectype>(arr, arrsize))) {
std_argsort_withnan(arr, arg, 0, arrsize);

if (descending) { std::reverse(arg, arg + arrsize); }

return;
}
}
UNUSED(hasnan);
argsort_64bit_<vectype, argtype>(
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));

if (descending) { std::reverse(arg, arg + arrsize); }
}
}

template <typename T>
X86_SIMD_SORT_INLINE std::vector<arrsize_t>
avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false)
X86_SIMD_SORT_INLINE std::vector<arrsize_t> avx2_argsort(
T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false)
{
std::vector<arrsize_t> indices(arrsize);
std::iota(indices.begin(), indices.end(), 0);
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan);
avx2_argsort<T>(arr, indices.data(), arrsize, hasnan, descending);
return indices;
}

Expand Down
24 changes: 22 additions & 2 deletions tests/test-qsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
}
}

TYPED_TEST_P(simdsort, test_argsort)
TYPED_TEST_P(simdsort, test_argsort_ascending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
Expand All @@ -89,6 +89,25 @@ TYPED_TEST_P(simdsort, test_argsort)
}
}

/* Descending argsort test: for every array type and size held by the fixture,
 * the indices returned by x86simdsort::argsort(..., descending = true) are
 * checked with IS_ARG_SORTED against a std::sort reference ordered with the
 * compare<TypeParam, std::greater<TypeParam>> comparator. */
TYPED_TEST_P(simdsort, test_argsort_descending)
{
    for (auto type : this->arrtype) {
        // Only the "rand_with_nan" inputs set the hasnan flag for argsort.
        bool hasnan = (type == "rand_with_nan");
        for (auto size : this->arrsize) {
            std::vector<TypeParam> arr = get_array<TypeParam>(type, size);

            // Reference ordering built from a copy of the unsorted input.
            std::vector<TypeParam> sortedarr(arr);
            compare<TypeParam, std::greater<TypeParam>> descending_cmp;
            std::sort(sortedarr.begin(), sortedarr.end(), descending_cmp);

            // Last argument `true` requests descending order.
            auto arg = x86simdsort::argsort(
                    arr.data(), arr.size(), hasnan, true);
            IS_ARG_SORTED(sortedarr, arr, arg, type);

            arr.clear();
            arg.clear();
        }
    }
}

TYPED_TEST_P(simdsort, test_qselect_ascending)
{
for (auto type : this->arrtype) {
Expand Down Expand Up @@ -241,7 +260,8 @@ TYPED_TEST_P(simdsort, test_comparator)
REGISTER_TYPED_TEST_SUITE_P(simdsort,
test_qsort_ascending,
test_qsort_descending,
test_argsort,
test_argsort_ascending,
test_argsort_descending,
test_argselect,
test_qselect_ascending,
test_qselect_descending,
Expand Down