From 93ef8c06f4a28ab9af8825f6202ccb548f4842b7 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 2 Apr 2024 10:26:52 -0700 Subject: [PATCH 1/2] Added descending sort to argsort --- README.md | 8 ++++---- benchmarks/bench-argsort.hpp | 17 +++++++++++++++++ lib/x86simdsort-avx2.cpp | 4 ++-- lib/x86simdsort-internal.h | 6 +++--- lib/x86simdsort-scalar.h | 10 +++++++--- lib/x86simdsort-skx.cpp | 4 ++-- lib/x86simdsort.cpp | 6 +++--- lib/x86simdsort.h | 2 +- src/README.md | 20 ++++++++++---------- src/xss-common-argsort.h | 32 +++++++++++++++++++++++++------- tests/test-qsort.cpp | 23 +++++++++++++++++++++-- 11 files changed, 95 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 77401446..924ee095 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,9 @@ how fast this is relative to `std::sort`. ## Sort an array of built-in integers and floats ```cpp -void x86simdsort::qsort(T* arr, size_t size, bool hasnan); -void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan); -void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan); +void x86simdsort::qsort(T* arr, size_t size, bool hasnan, bool descending); +void x86simdsort::qselect(T* arr, size_t k, size_t size, bool hasnan, bool descending); +void x86simdsort::partial_qsort(T* arr, size_t k, size_t size, bool hasnan, bool descending); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, int32_t, double, uint64_t, int64_t]` @@ -53,7 +53,7 @@ data types. ## Arg sort routines on arrays ```cpp -std::vector arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan); +std::vector arg = x86simdsort::argsort(T* arr, size_t size, bool hasnan, bool descending); std::vector arg = x86simdsort::argselect(T* arr, size_t k, size_t size, bool hasnan); ``` Supported datatypes: `T` $\in$ `[_Float16, uint16_t, int16_t, float, uint32_t, diff --git a/benchmarks/bench-argsort.hpp b/benchmarks/bench-argsort.hpp index 419bed13..4a2413f7 100644 --- a/benchmarks/bench-argsort.hpp +++ b/benchmarks/bench-argsort.hpp @@ -45,6 +45,22 @@ static void simdargsort(benchmark::State &state, Args &&...args) } } +template +static void simd_revargsort(benchmark::State &state, Args &&...args) +{ + // get args + auto args_tuple = std::make_tuple(std::move(args)...); + size_t arrsize = std::get<0>(args_tuple); + std::string arrtype = std::get<1>(args_tuple); + // set up array + std::vector arr = get_array(arrtype, arrsize); + std::vector inx; + // benchmark + for (auto _ : state) { + inx = x86simdsort::argsort(arr.data(), arrsize, false, true); + } +} + template static void simd_ordern_argsort(benchmark::State &state, Args &&...args) { @@ -68,6 +84,7 @@ static void simd_ordern_argsort(benchmark::State &state, Args &&...args) #define BENCH_BOTH(type) \ BENCH_SORT(simdargsort, type) \ + BENCH_SORT(simd_revargsort, type) \ BENCH_SORT(simd_ordern_argsort, type) \ BENCH_SORT(scalarargsort, type) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 345653d9..0754a640 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -24,9 +24,9 @@ avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ + std::vector argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - return avx2_argsort(arr, arrsize, hasnan); \ + return avx2_argsort(arr, arrsize, hasnan, descending); \ } \ template <> \ std::vector argselect( \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index dad32b91..5a5682e6 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -31,7 +31,7 @@ namespace avx512 { // argsort template XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false); + argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector @@ -63,7 +63,7 @@ namespace avx2 { // argsort template XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false); + argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector @@ -95,7 +95,7 @@ namespace scalar { // argsort template XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false); + argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 6afc7287..8b4ddb09 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -70,12 +70,16 @@ namespace scalar { xss::utils::get_cmp_func(hasnan, reversed)); } template - std::vector argsort(T *arr, size_t arrsize, bool hasnan) + std::vector argsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - std::sort(arg.begin(), arg.end(), compare_arg>(arr)); + if (reversed){ + std::sort(arg.begin(), arg.end(), compare_arg>(arr)); + }else{ + std::sort(arg.begin(), arg.end(), compare_arg>(arr)); + } return arg; } template @@ -93,7 +97,7 @@ namespace scalar { template void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan) { - std::vector arg = argsort(key, arrsize, hasnan); + std::vector arg = argsort(key, arrsize, hasnan, false); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 4a1c2a9f..811c6d8c 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -24,9 +24,9 @@ avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - std::vector argsort(type *arr, size_t arrsize, bool hasnan) \ + std::vector argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - return avx512_argsort(arr, arrsize, hasnan); \ + return avx512_argsort(arr, arrsize, hasnan, descending); \ } \ template <> \ std::vector argselect( \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 21c8b34f..6d202c34 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -86,12 +86,12 @@ namespace x86simdsort { } #define DECLARE_INTERNAL_argsort(TYPE) \ - static std::vector (*internal_argsort##TYPE)(TYPE *, size_t, bool) \ + static std::vector (*internal_argsort##TYPE)(TYPE *, size_t, bool, bool) \ = NULL; \ template <> \ - std::vector argsort(TYPE *arr, size_t arrsize, bool hasnan) \ + std::vector argsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ - return (*internal_argsort##TYPE)(arr, arrsize, hasnan); \ + return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ } #define DECLARE_INTERNAL_argselect(TYPE) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 42d5247f..0a85f5ea 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -36,7 +36,7 @@ XSS_EXPORT_SYMBOL void partial_qsort(T *arr, // argsort template XSS_EXPORT_SYMBOL std::vector -argsort(T *arr, size_t arrsize, bool hasnan = false); +argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // argselect template diff --git a/src/README.md b/src/README.md index 8a56658f..8030e90e 100644 --- a/src/README.md +++ b/src/README.md @@ -13,8 +13,8 @@ Equivalent to `qsort` in `std::sort` in [C++](https://en.cppreference.com/w/cpp/algorithm/sort). ```cpp -void avx512_qsort(T* arr, size_t arrsize, bool hasnan = false); -void avx2_qsort(T* arr, size_t arrsize, bool hasnan = false); +void avx512_qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); +void avx2_qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support @@ -30,8 +30,8 @@ Equivalent to `std::nth_element` in ```cpp -void avx512_qselect(T* arr, size_t arrsize, bool hasnan = false); -void avx2_qselect(T* arr, size_t arrsize, bool hasnan = false); +void avx512_qselect(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); +void avx2_qselect(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support @@ -46,8 +46,8 @@ Equivalent to `std::partial_sort` in ```cpp -void avx512_partial_qsort(T* arr, size_t arrsize, bool hasnan = false) -void avx2_partial_qsort(T* arr, size_t arrsize, bool hasnan = false) +void avx512_partial_qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false) +void avx2_partial_qsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false) ``` Supported datatypes: `uint16_t`, `int16_t`, `_Float16`, `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. AVX2 versions currently support @@ -61,8 +61,8 @@ Equivalent to `np.argsort` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argsort.html). ```cpp -std::vector arg = avx512_argsort(T* arr, size_t arrsize); -void avx512_argsort(T* arr, size_t *arg, size_t arrsize); +std::vector arg = avx512_argsort(T* arr, size_t arrsize, bool hasnan = false, bool descending = false); +void avx512_argsort(T* arr, size_t *arg, size_t arrsize, bool hasnan = false, bool descending = false); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. @@ -74,8 +74,8 @@ Equivalent to `np.argselect` in [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html). ```cpp -std::vector arg = avx512_argsort(T* arr, size_t arrsize); -void avx512_argsort(T* arr, size_t *arg, size_t arrsize); +std::vector arg = avx512_argselect(T* arr, size_t k, size_t arrsize); +void avx512_argselect(T* arr, size_t *arg, size_t k, size_t arrsize); ``` Supported datatypes: `uint32_t`, `int32_t`, `float`, `uint64_t`, `int64_t` and `double`. diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 46ce7ef7..df807656 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -542,7 +542,7 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, /* argsort methods for 32-bit and 64-bit dtypes */ template X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false) { /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan(arr, arg, 0, arrsize); + + if (descending){ + std::reverse(arg, arg + arrsize); + } + return; } } UNUSED(hasnan); argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + if (descending){ + std::reverse(arg, arg + arrsize); + } } } template X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); - avx512_argsort(arr, indices.data(), arrsize, hasnan); + avx512_argsort(arr, indices.data(), arrsize, hasnan, descending); return indices; } /* argsort methods for 32-bit and 64-bit dtypes */ template X86_SIMD_SORT_INLINE void -avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) +avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false) { using vectype = typename std::conditional, @@ -594,22 +603,31 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false) if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan(arr, arg, 0, arrsize); + + if (descending){ + std::reverse(arg, arg + arrsize); + } + return; } } UNUSED(hasnan); argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + + if (descending){ + std::reverse(arg, arg + arrsize); + } } } template X86_SIMD_SORT_INLINE std::vector -avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false) +avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); - avx2_argsort(arr, indices.data(), arrsize, hasnan); + avx2_argsort(arr, indices.data(), arrsize, hasnan, descending); return indices; } @@ -631,7 +649,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, ymm_vector, zmm_vector>::type; - if (arrsize > 1) { + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argselect_withnan(arr, arg, k, 0, arrsize); diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 5d4ba587..f1651390 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -71,7 +71,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending) } } -TYPED_TEST_P(simdsort, test_argsort) +TYPED_TEST_P(simdsort, test_argsort_ascending) { for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; @@ -89,6 +89,24 @@ TYPED_TEST_P(simdsort, test_argsort) } } +TYPED_TEST_P(simdsort, test_argsort_descending) +{ + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + std::vector arr = get_array(type, size); + std::vector sortedarr = arr; + std::sort(sortedarr.begin(), + sortedarr.end(), + compare>()); + auto arg = x86simdsort::argsort(arr.data(), arr.size(), hasnan, true); + IS_ARG_SORTED(sortedarr, arr, arg, type); + arr.clear(); + arg.clear(); + } + } +} + TYPED_TEST_P(simdsort, test_qselect_ascending) { for (auto type : this->arrtype) { @@ -241,7 +259,8 @@ TYPED_TEST_P(simdsort, test_comparator) REGISTER_TYPED_TEST_SUITE_P(simdsort, test_qsort_ascending, test_qsort_descending, - test_argsort, + test_argsort_ascending, + test_argsort_descending, test_argselect, test_qselect_ascending, test_qselect_descending, From d58de527625d98fd3dd85e3be50b31d6bffcfa7e Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 2 Apr 2024 10:29:57 -0700 Subject: [PATCH 2/2] Formatting --- lib/x86simdsort-avx2.cpp | 3 ++- lib/x86simdsort-internal.h | 18 ++++++++----- lib/x86simdsort-scalar.h | 15 +++++++---- lib/x86simdsort-skx.cpp | 3 ++- lib/x86simdsort.cpp | 6 +++-- src/xss-common-argsort.h | 52 ++++++++++++++++++-------------------- tests/test-qsort.cpp | 3 ++- 7 files changed, 57 insertions(+), 43 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 0754a640..e10fc164 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -24,7 +24,8 @@ avx2_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - std::vector argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort( \ + type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return avx2_argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 5a5682e6..1bbb4067 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -30,8 +30,10 @@ namespace avx512 { bool descending = false); // argsort template - XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL std::vector argsort(T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector @@ -62,8 +64,10 @@ namespace avx2 { bool descending = false); // argsort template - XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL std::vector argsort(T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector @@ -94,8 +98,10 @@ namespace scalar { bool descending = false); // argsort template - XSS_HIDE_SYMBOL std::vector - argsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL std::vector argsort(T *arr, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argselect template XSS_HIDE_SYMBOL std::vector diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index 8b4ddb09..e5ac6ab6 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -70,15 +70,20 @@ namespace scalar { xss::utils::get_cmp_func(hasnan, reversed)); } template - std::vector argsort(T *arr, size_t arrsize, bool hasnan, bool reversed) + std::vector + argsort(T *arr, size_t arrsize, bool hasnan, bool reversed) { UNUSED(hasnan); std::vector arg(arrsize); std::iota(arg.begin(), arg.end(), 0); - if (reversed){ - std::sort(arg.begin(), arg.end(), compare_arg>(arr)); - }else{ - std::sort(arg.begin(), arg.end(), compare_arg>(arr)); + if (reversed) { + std::sort(arg.begin(), + arg.end(), + compare_arg>(arr)); + } + else { + std::sort( + arg.begin(), arg.end(), compare_arg>(arr)); } return arg; } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 811c6d8c..8b154d4e 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -24,7 +24,8 @@ avx512_partial_qsort(arr, k, arrsize, hasnan, descending); \ } \ template <> \ - std::vector argsort(type *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort( \ + type *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return avx512_argsort(arr, arrsize, hasnan, descending); \ } \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 6d202c34..0c16f148 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -86,10 +86,12 @@ namespace x86simdsort { } #define DECLARE_INTERNAL_argsort(TYPE) \ - static std::vector (*internal_argsort##TYPE)(TYPE *, size_t, bool, bool) \ + static std::vector (*internal_argsort##TYPE)( \ + TYPE *, size_t, bool, bool) \ = NULL; \ template <> \ - std::vector argsort(TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ + std::vector argsort( \ + TYPE *arr, size_t arrsize, bool hasnan, bool descending) \ { \ return (*internal_argsort##TYPE)(arr, arrsize, hasnan, descending); \ } diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index df807656..b97dd0d0 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -541,8 +541,11 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr, /* argsort methods for 32-bit and 64-bit dtypes */ template -X86_SIMD_SORT_INLINE void -avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx512_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { /* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ using vectype = typename std::conditional) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan(arr, arg, 0, arrsize); - - if (descending){ - std::reverse(arg, arg + arrsize); - } - + + if (descending) { std::reverse(arg, arg + arrsize); } + return; } } UNUSED(hasnan); argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - - if (descending){ - std::reverse(arg, arg + arrsize); - } + + if (descending) { std::reverse(arg, arg + arrsize); } } } template -X86_SIMD_SORT_INLINE std::vector -avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE std::vector avx512_argsort( + T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); @@ -588,8 +587,11 @@ avx512_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = /* argsort methods for 32-bit and 64-bit dtypes */ template -X86_SIMD_SORT_INLINE void -avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx2_argsort(T *arr, + arrsize_t *arg, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { using vectype = typename std::conditional, @@ -603,27 +605,23 @@ avx2_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false, boo if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argsort_withnan(arr, arg, 0, arrsize); - - if (descending){ - std::reverse(arg, arg + arrsize); - } - + + if (descending) { std::reverse(arg, arg + arrsize); } + return; } } UNUSED(hasnan); argsort_64bit_( arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - - if (descending){ - std::reverse(arg, arg + arrsize); - } + + if (descending) { std::reverse(arg, arg + arrsize); } } } template -X86_SIMD_SORT_INLINE std::vector -avx2_argsort(T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE std::vector avx2_argsort( + T *arr, arrsize_t arrsize, bool hasnan = false, bool descending = false) { std::vector indices(arrsize); std::iota(indices.begin(), indices.end(), 0); @@ -649,7 +647,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr, ymm_vector, zmm_vector>::type; - if (arrsize > 1) { + if (arrsize > 1) { if constexpr (std::is_floating_point_v) { if ((hasnan) && (array_has_nan(arr, arrsize))) { std_argselect_withnan(arr, arg, k, 0, arrsize); diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index f1651390..5ebd018f 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -99,7 +99,8 @@ TYPED_TEST_P(simdsort, test_argsort_descending) std::sort(sortedarr.begin(), sortedarr.end(), compare>()); - auto arg = x86simdsort::argsort(arr.data(), arr.size(), hasnan, true); + auto arg = x86simdsort::argsort( + arr.data(), arr.size(), hasnan, true); IS_ARG_SORTED(sortedarr, arr, arg, type); arr.clear(); arg.clear();