From 16fe62ce700769f006fe8f21f21197f74ff31edb Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 26 Apr 2024 09:11:22 -0700 Subject: [PATCH 1/8] Improves test for kv-sort logic --- tests/test-keyvalue.cpp | 107 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 5 deletions(-) diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index fda9130d..f7d26d64 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -29,6 +29,57 @@ class simdkvsort : public ::testing::Test { TYPED_TEST_SUITE_P(simdkvsort); +template +bool same_values(T* v1, T* v2, size_t size){ + // Checks that the values are the same except (maybe) their ordering + auto cmp_eq = compare>(); + + // TODO hardcoding hasnan to true doesn't break anything right? + x86simdsort::qsort(v1, size, true); + x86simdsort::qsort(v2, size, true); + + for (size_t i = 0; i < size; i++){ + if (!cmp_eq(v1[i], v2[i])){ + return false; + } + } + + return true; +} + +template +bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){ + auto cmp_eq = compare>(); + + // First check keys are exactly identical + for (size_t i = 0; i < size; i++){ + if (!cmp_eq(keys_comp[i], keys_ref[i])){ + return false; + } + } + + size_t i_start = 0; + T1 key_start = keys_comp[0]; + // Loop through all identical keys in a block, then compare the sets of values to make sure they are identical + // We need the index after the loop + size_t i = 0; + for (; i < size; i++){ + if (!cmp_eq(keys_comp[i], key_start)){ + // Check that every value in + + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + return false; + } + + // Now setup the start variables to begin gathering keys for the next group + i_start = i; + key_start = keys_comp[i]; + } + } + + return true; +} + TYPED_TEST_P(simdkvsort, test_kvsort) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; @@ -43,10 +94,10 @@ TYPED_TEST_P(simdkvsort, test_kvsort) x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan); xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan); - ASSERT_EQ(key, key_bckp); - const bool hasDuplicates - = std::adjacent_find(key.begin(), key.end()) != key.end(); - if (!hasDuplicates) { ASSERT_EQ(val, val_bckp); } + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + ASSERT_EQ(is_kv_equivalent, true); + key.clear(); val.clear(); key_bckp.clear(); @@ -55,7 +106,53 @@ TYPED_TEST_P(simdkvsort, test_kvsort) } } -REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort); +TYPED_TEST_P(simdkvsort, test_validator) +{ + // Tests a few edge cases to verify the tests are working correctly and identifying it as functional + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + + bool is_kv_equivalent; + + std::vector key = {0, 0, 1, 1}; + std::vector val = {1, 2, 3, 4}; + std::vector key_bckp = key; + std::vector val_bckp = val; + + // Duplicate keys, but otherwise exactly identical + is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + ASSERT_EQ(is_kv_equivalent, true); + + val = {2,1,4,3}; + + // Now values are backwards, but this is still fine + is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + ASSERT_EQ(is_kv_equivalent, true); + + val = {1,3,2,4}; + + // Now values are mixed up, should fail + is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + ASSERT_EQ(is_kv_equivalent, false); + + val = {1,2,3,4}; + key = {0,0,0,0}; + + // Now keys are messed up, should fail + is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + ASSERT_EQ(is_kv_equivalent, false); + + key = {0,0,0,0,0,0}; + key_bckp = key; + val_bckp = {1,2,3,4,5,6}; + val = {4,3,1,6,5,2}; + + // All keys identical, simply reordered values + is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + ASSERT_EQ(is_kv_equivalent, true); +} + +REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort, test_validator); #define CREATE_TUPLES(type) \ std::tuple, std::tuple, \ From 617589292fc3c69039b05ee2ffaca525bf67d195 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 26 Apr 2024 15:57:54 -0700 Subject: [PATCH 2/8] Testing for kvselect and kvpartial_sort --- tests/test-keyvalue.cpp | 67 ++++++++++++++++++++++++++++++++++++++++- tests/test-qsort.cpp | 3 +- utils/custom-compare.h | 5 +++ 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index f7d26d64..51e17591 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -6,6 +6,7 @@ #include "rand_array.h" #include "x86simdsort.h" #include "x86simdsort-scalar.h" +#include "test-qsort-common.h" #include template @@ -106,6 +107,70 @@ TYPED_TEST_P(simdkvsort, test_kvsort) } } +TYPED_TEST_P(simdkvsort, test_kvselect) +{ + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + + std::vector key = get_array(type, size); + std::vector val = get_array(type, size); + std::vector key_bckp = key; + std::vector val_bckp = val; + + xss::scalar::keyvalue_qsort( + key_bckp.data(), val_bckp.data(), size, hasnan); + + // Test select by using it as part of partial_sort + x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan); + IS_ARR_PARTITIONED(key, k, key_bckp[k], type); + xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan); + + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); + ASSERT_EQ(is_kv_equivalent, true); + + key.clear(); + val.clear(); + key_bckp.clear(); + val_bckp.clear(); + } + } +} + +TYPED_TEST_P(simdkvsort, test_kvpartial_sort) +{ + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + + std::vector key = get_array(type, size); + std::vector val = get_array(type, size); + std::vector key_bckp = key; + std::vector val_bckp = val; + x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan); + xss::scalar::keyvalue_qsort( + key_bckp.data(), val_bckp.data(), size, hasnan); + + IS_ARR_PARTIALSORTED(key, k, key_bckp, type); + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); + ASSERT_EQ(is_kv_equivalent, true); + + key.clear(); + val.clear(); + key_bckp.clear(); + val_bckp.clear(); + } + } +} + TYPED_TEST_P(simdkvsort, test_validator) { // Tests a few edge cases to verify the tests are working correctly and identifying it as functional @@ -152,7 +217,7 @@ TYPED_TEST_P(simdkvsort, test_validator) ASSERT_EQ(is_kv_equivalent, true); } -REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort, test_validator); +REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort, test_kvselect, test_kvpartial_sort, test_validator); #define CREATE_TUPLES(type) \ std::tuple, std::tuple, \ diff --git a/tests/test-qsort.cpp b/tests/test-qsort.cpp index 5ebd018f..0df7addf 100644 --- a/tests/test-qsort.cpp +++ b/tests/test-qsort.cpp @@ -181,8 +181,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_ascending) for (auto type : this->arrtype) { bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { - // k should be at least 1 - size_t k = std::max((size_t)1, rand() % size); + size_t k = rand() % size; std::vector basearr = get_array(type, size); // Ascending order diff --git a/utils/custom-compare.h b/utils/custom-compare.h index 6244bb24..f2c8d61e 100644 --- a/utils/custom-compare.h +++ b/utils/custom-compare.h @@ -1,3 +1,6 @@ +#ifndef UTILS_CUSTOM_COMPARE +#define UTILS_CUSTOM_COMPARE + #include #include #include "xss-custom-float.h" @@ -42,3 +45,5 @@ struct compare_arg { } const T *arr; }; + +#endif // UTILS_CUSTOM_COMPARE \ No newline at end of file From c2fa38ca08e2ac5ba1d3b761b24f83b5195a356f Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 26 Apr 2024 15:58:36 -0700 Subject: [PATCH 3/8] Support for key-value select and partial sort --- lib/x86simdsort-avx2.cpp | 35 ++++----- lib/x86simdsort-internal.h | 24 ++++++ lib/x86simdsort-scalar.h | 21 ++++++ lib/x86simdsort-skx.cpp | 35 ++++----- lib/x86simdsort.cpp | 62 ++++++++++++++++ lib/x86simdsort.h | 10 +++ src/x86simdsort-static-incl.h | 20 +++++ src/xss-common-keyvaluesort.hpp | 125 ++++++++++++++++++++++++++++++++ 8 files changed, 290 insertions(+), 42 deletions(-) diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 4c1123e4..91ec4aa5 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -34,37 +34,30 @@ return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ } -#define DEFINE_KEYVALUE_METHODS(type) \ - template <> \ - void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ - template <> \ - void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ - template <> \ - void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ +#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \ + void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \ { \ x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ - void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \ + void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \ } \ template <> \ - void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \ + void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \ } + +#define DEFINE_KEYVALUE_METHODS(type) \ + DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, double) \ + DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, float) namespace xss { namespace avx2 { diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index a74de690..7f074771 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -21,6 +21,10 @@ namespace avx512 { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value select + template + XSS_EXPORT_SYMBOL void + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -28,6 +32,10 @@ namespace avx512 { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value partial sort + template + XSS_EXPORT_SYMBOL void + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -55,6 +63,10 @@ namespace avx2 { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value select + template + XSS_EXPORT_SYMBOL void + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -62,6 +74,10 @@ namespace avx2 { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value partial sort + template + XSS_EXPORT_SYMBOL void + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -89,6 +105,10 @@ namespace scalar { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value select + template + XSS_EXPORT_SYMBOL void + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -96,6 +116,10 @@ namespace scalar { size_t arrsize, bool hasnan = false, bool descending = false); + // key-value partial sort + template + XSS_EXPORT_SYMBOL void + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index e5ac6ab6..fcd64f3f 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -106,6 +106,27 @@ namespace scalar { utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } + template + void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan) + { + if (k == 0) return; + // Note that this does a full partial sort, not just a select + std::vector arg = argsort(key, arrsize, hasnan, false); + //arg.resize(k); + + utils::apply_permutation_in_place(key, arg); + utils::apply_permutation_in_place(val, arg); + } + template + void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan) + { + if (k == 0) return; + std::vector arg = argsort(key, arrsize, hasnan, false); + //arg.resize(k); + + utils::apply_permutation_in_place(key, arg); + utils::apply_permutation_in_place(val, arg); + } } // namespace scalar } // namespace xss diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index e51c51ed..635dd6e3 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -34,37 +34,30 @@ return x86simdsortStatic::argselect(arr, k, arrsize, hasnan); \ } -#define DEFINE_KEYVALUE_METHODS(type) \ - template <> \ - void keyvalue_qsort(type *key, uint64_t *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ - template <> \ - void keyvalue_qsort(type *key, int64_t *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ - template <> \ - void keyvalue_qsort(type *key, double *val, size_t arrsize, bool hasnan) \ - { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ - } \ +#define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type *key, uint32_t *val, size_t arrsize, bool hasnan) \ + void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \ { \ x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ } \ template <> \ - void keyvalue_qsort(type *key, int32_t *val, size_t arrsize, bool hasnan) \ + void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \ } \ template <> \ - void keyvalue_qsort(type *key, float *val, size_t arrsize, bool hasnan) \ + void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \ } + +#define DEFINE_KEYVALUE_METHODS(type) \ + DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, double) \ + DEFINE_KEYVALUE_METHODS_BASE(type, uint32_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, int32_t) \ + DEFINE_KEYVALUE_METHODS_BASE(type, float) namespace xss { namespace avx512 { diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 2f268abc..df7d2cc2 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -160,6 +160,68 @@ namespace x86simdsort { return; \ } \ } \ + }\ + static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ + TYPE1 *, TYPE2 *, size_t, size_t, bool) \ + = NULL; \ + template <> \ + void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \ + { \ + (CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ + key, val, k, arrsize, hasnan); \ + } \ + static __attribute__((constructor)) void CAT( \ + CAT(resolve_keyvalue_select_, TYPE1), TYPE2)(void) \ + { \ + CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ + = &xss::scalar::keyvalue_select; \ + __builtin_cpu_init(); \ + std::string_view preferred_cpu = find_preferred_cpu(ISA); \ + if constexpr (dispatch_requested("avx512", ISA)) { \ + if (preferred_cpu.find("avx512") != std::string_view::npos) { \ + CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ + = &xss::avx512::keyvalue_select; \ + return; \ + } \ + } \ + if constexpr (dispatch_requested("avx2", ISA)) { \ + if (preferred_cpu.find("avx2") != std::string_view::npos) { \ + CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ + = &xss::avx2::keyvalue_select; \ + return; \ + } \ + } \ + } \ + static void(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ + TYPE1 *, TYPE2 *, size_t, size_t, bool) \ + = NULL; \ + template <> \ + void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \ + { \ + (CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ + key, val, k, arrsize, hasnan); \ + } \ + static __attribute__((constructor)) void CAT( \ + CAT(resolve_keyvalue_partial_sort_, TYPE1), TYPE2)(void) \ + { \ + CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ + = &xss::scalar::keyvalue_partial_sort; \ + __builtin_cpu_init(); \ + std::string_view preferred_cpu = find_preferred_cpu(ISA); \ + if constexpr (dispatch_requested("avx512", ISA)) { \ + if (preferred_cpu.find("avx512") != std::string_view::npos) { \ + CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ + = &xss::avx512::keyvalue_partial_sort; \ + return; \ + } \ + } \ + if constexpr (dispatch_requested("avx2", ISA)) { \ + if (preferred_cpu.find("avx2") != std::string_view::npos) { \ + CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ + = &xss::avx2::keyvalue_partial_sort; \ + return; \ + } \ + } \ } #define ISA_LIST(...) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 0a85f5ea..f2e09269 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -48,6 +48,16 @@ template XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); +// keyvalue select +template +XSS_EXPORT_SYMBOL void +keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + +// keyvalue partial sort +template +XSS_EXPORT_SYMBOL void +keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + // sort an object template XSS_EXPORT_SYMBOL void object_qsort(T *arr, uint32_t arrsize, Func key_func) diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 1f849004..9f1973cc 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -49,6 +49,14 @@ template X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); +template +X86_SIMD_SORT_FINLINE void +keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false); + +template +X86_SIMD_SORT_FINLINE void +keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false); + } // namespace x86simdsortStatic #define XSS_METHODS(ISA) \ @@ -106,6 +114,18 @@ keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); T1 *key, T2 *val, size_t size, bool hasnan) \ { \ ISA##_qsort_kv(key, val, size, hasnan); \ + } \ + template \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_select( \ + T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \ + { \ + ISA##_select_kv(key, val, k, size, hasnan); \ + } \ + template \ + X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_partial_sort( \ + T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \ + { \ + ISA##_partial_sort_kv(key, val, k, size, hasnan); \ } /* diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 4699b8a1..94ee6aca 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -401,6 +401,50 @@ X86_SIMD_SORT_INLINE void kvsort_(type1_t *keys, } } +template +X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, + type2_t *indexes, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + int max_iters) +{ + /* + * Resort to std::sort if quicksort isnt making any progress + */ + if (max_iters <= 0) { + heap_sort( + keys + left, indexes + left, right - left + 1); + return; + } + /* + * Base case: use bitonic networks to sort arrays <= 128 + */ + if (right + 1 - left <= 128) { + + kvsort_n( + keys + left, indexes + left, (int32_t)(right + 1 - left)); + return; + } + + type1_t pivot = get_pivot_blocks(keys, left, right); + type1_t smallest = vtype1::type_max(); + type1_t biggest = vtype1::type_min(); + arrsize_t pivot_index = kvpartition_unrolled( + keys, indexes, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) { + kvselect_( + keys, indexes, pos, left, pivot_index - 1, max_iters - 1); + } + else if ((pivot != biggest) && (pos >= pivot_index)) { + kvselect_( + keys, indexes, pos, pivot_index, right, max_iters - 1); + } +} + template @@ -445,6 +489,55 @@ xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan) } } +template typename full_vector, template typename half_vector> +X86_SIMD_SORT_INLINE void +xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + using keytype = + typename std::conditional, + full_vector>::type; + using valtype = + typename std::conditional, + full_vector>::type; + +#ifdef XSS_TEST_KEYVALUE_BASE_CASE + int maxiters = -1; + bool minarrsize = true; +#else + int maxiters = 2 * log2(arrsize); + bool minarrsize = arrsize > 1 ? true : false; +#endif // XSS_TEST_KEYVALUE_BASE_CASE + + if (minarrsize) { + if constexpr (std::is_floating_point_v) { + arrsize_t nan_count = 0; + if (UNLIKELY(hasnan)) { + nan_count + = replace_nan_with_inf>(keys, arrsize); + } + kvselect_(keys, indexes, k, 0, arrsize - 1, maxiters); + replace_inf_with_nan(keys, arrsize, nan_count); + } + else { + UNUSED(hasnan); + kvselect_(keys, indexes, k, 0, arrsize - 1, maxiters); + } + } +} + +template typename full_vector, template typename half_vector> +X86_SIMD_SORT_INLINE void +xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + if (k == 0) return; + xss_select_kv(keys, indexes, k - 1, arrsize, hasnan); + xss_qsort_kv(keys, indexes, k - 1, hasnan); +} + template X86_SIMD_SORT_INLINE void avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) @@ -460,4 +553,36 @@ avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) xss_qsort_kv( keys, indexes, arrsize, hasnan); } + +template +X86_SIMD_SORT_INLINE void +avx512_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + xss_select_kv( + keys, indexes, k, arrsize, hasnan); +} + +template +X86_SIMD_SORT_INLINE void +avx2_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + xss_select_kv( + keys, indexes, k, arrsize, hasnan); +} + +template +X86_SIMD_SORT_INLINE void +avx512_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + xss_partial_sort_kv( + keys, indexes, k, arrsize, hasnan); +} + +template +X86_SIMD_SORT_INLINE void +avx2_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +{ + xss_partial_sort_kv( + keys, indexes, k, arrsize, hasnan); +} #endif // AVX512_QSORT_64BIT_KV From 93b9c993e2230bce79adce438d238bcedd6da2e8 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 1 May 2024 13:32:24 -0700 Subject: [PATCH 4/8] Adds descending sort for kvsort, kvselect, and kvpartial_sort and related tests --- benchmarks/bench-keyvalue.hpp | 2 +- lib/x86simdsort-avx2.cpp | 12 ++-- lib/x86simdsort-internal.h | 16 ++--- lib/x86simdsort-scalar.h | 12 ++-- lib/x86simdsort-skx.cpp | 12 ++-- lib/x86simdsort.cpp | 18 ++--- lib/x86simdsort.h | 6 +- src/x86simdsort-static-incl.h | 18 ++--- src/xss-common-keyvaluesort.hpp | 50 +++++++++----- src/xss-common-qsort.h | 1 + tests/test-keyvalue.cpp | 112 ++++++++++++++++++++++++++++---- 11 files changed, 183 insertions(+), 76 deletions(-) diff --git a/benchmarks/bench-keyvalue.hpp b/benchmarks/bench-keyvalue.hpp index 1eaab9e9..5ed8d48a 100644 --- a/benchmarks/bench-keyvalue.hpp +++ b/benchmarks/bench-keyvalue.hpp @@ -13,7 +13,7 @@ static void scalarkvsort(benchmark::State &state, Args &&...args) std::vector key_bkp = key; // benchmark for (auto _ : state) { - xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false); + xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false, false); state.PauseTiming(); key = key_bkp; state.ResumeTiming(); diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 91ec4aa5..9bbef8cd 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -36,19 +36,19 @@ #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \ + void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \ } #define DEFINE_KEYVALUE_METHODS(type) \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 7f074771..8ef6814c 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -13,7 +13,7 @@ namespace avx512 { // key-value quicksort template XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); + keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -24,7 +24,7 @@ namespace avx512 { // key-value select template XSS_EXPORT_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -35,7 +35,7 @@ namespace avx512 { // key-value partial sort template XSS_EXPORT_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -66,7 +66,7 @@ namespace avx2 { // key-value select template XSS_EXPORT_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -77,7 +77,7 @@ namespace avx2 { // key-value partial sort template XSS_EXPORT_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -97,7 +97,7 @@ namespace scalar { // key-value quicksort template XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); + keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -108,7 +108,7 @@ namespace scalar { // key-value select template XSS_EXPORT_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -119,7 +119,7 @@ namespace scalar { // key-value partial sort template XSS_EXPORT_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); + keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index fcd64f3f..ce383f5a 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -100,28 +100,28 @@ namespace scalar { return arg; } template - void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan) + void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) { - std::vector arg = argsort(key, arrsize, hasnan, false); + std::vector arg = argsort(key, arrsize, hasnan, descending); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } template - void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan) + void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) { if (k == 0) return; // Note that this does a full partial sort, not just a select - std::vector arg = argsort(key, arrsize, hasnan, false); + std::vector arg = argsort(key, arrsize, hasnan, descending); //arg.resize(k); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } template - void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan) + void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) { if (k == 0) return; - std::vector arg = argsort(key, arrsize, hasnan, false); + std::vector arg = argsort(key, arrsize, hasnan, descending); //arg.resize(k); utils::apply_permutation_in_place(key, arg); diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 635dd6e3..2917ae59 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -36,19 +36,19 @@ #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan) \ + void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ - x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan); \ + x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \ } #define DEFINE_KEYVALUE_METHODS(type) \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index df7d2cc2..64885978 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -131,13 +131,13 @@ namespace x86simdsort { #define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \ static void(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \ - TYPE1 *, TYPE2 *, size_t, bool) \ + TYPE1 *, TYPE2 *, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_qsort(TYPE1 *key, TYPE2 *val, size_t arrsize, bool hasnan) \ + void keyvalue_qsort(TYPE1 *key, TYPE2 *val, size_t arrsize, bool hasnan, bool descending) \ { \ (CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \ - key, val, arrsize, hasnan); \ + key, val, arrsize, hasnan, descending); \ } \ static __attribute__((constructor)) void CAT( \ CAT(resolve_keyvalue_qsort_, TYPE1), TYPE2)(void) \ @@ -162,13 +162,13 @@ namespace x86simdsort { } \ }\ static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ - TYPE1 *, TYPE2 *, size_t, size_t, bool) \ + TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ (CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ - key, val, k, arrsize, hasnan); \ + key, val, k, arrsize, hasnan, descending); \ } \ static __attribute__((constructor)) void CAT( \ CAT(resolve_keyvalue_select_, TYPE1), TYPE2)(void) \ @@ -193,13 +193,13 @@ namespace x86simdsort { } \ } \ static void(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ - TYPE1 *, TYPE2 *, size_t, size_t, bool) \ + TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan) \ + void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ { \ (CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ - key, val, k, arrsize, hasnan); \ + key, val, k, arrsize, hasnan, descending); \ } \ static __attribute__((constructor)) void CAT( \ CAT(resolve_keyvalue_partial_sort_, TYPE1), TYPE2)(void) \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index f2e09269..8d924e97 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -46,17 +46,17 @@ argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); // keyvalue sort template XSS_EXPORT_SYMBOL void -keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); +keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); // keyvalue select template XSS_EXPORT_SYMBOL void -keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); +keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // keyvalue partial sort template XSS_EXPORT_SYMBOL void -keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false); +keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // sort an object template diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 9f1973cc..21c29a67 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -47,15 +47,15 @@ argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); template X86_SIMD_SORT_FINLINE void -keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false); +keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false, bool descending = false); template X86_SIMD_SORT_FINLINE void -keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false); +keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false); template X86_SIMD_SORT_FINLINE void -keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false); +keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false); } // namespace x86simdsortStatic @@ -111,21 +111,21 @@ keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = fal } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_qsort( \ - T1 *key, T2 *val, size_t size, bool hasnan) \ + T1 *key, T2 *val, size_t size, bool hasnan, bool descending) \ { \ - ISA##_qsort_kv(key, val, size, hasnan); \ + ISA##_qsort_kv(key, val, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_select( \ - T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \ + T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \ { \ - ISA##_select_kv(key, val, k, size, hasnan); \ + ISA##_select_kv(key, val, k, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_partial_sort( \ - T1 *key, T2 *val, size_t k, size_t size, bool hasnan) \ + T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \ { \ - ISA##_partial_sort_kv(key, val, k, size, hasnan); \ + ISA##_partial_sort_kv(key, val, k, size, hasnan, descending); \ } /* diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 94ee6aca..2ec5614a 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -435,6 +435,7 @@ X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, type1_t biggest = vtype1::type_min(); arrsize_t pivot_index = kvpartition_unrolled( keys, indexes, left, right + 1, pivot, &smallest, &biggest); + if ((pivot != smallest) && (pos < pivot_index)) { kvselect_( keys, indexes, pos, left, pivot_index - 1, max_iters - 1); @@ -452,7 +453,7 @@ template typename half_vector> X86_SIMD_SORT_INLINE void -xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan) +xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descending) { using keytype = typename std::conditional(keys, indexes, 0, arrsize - 1, maxiters); replace_inf_with_nan(keys, arrsize, nan_count); + + if (descending) { + std::reverse(keys, keys + arrsize); + std::reverse(indexes, indexes + arrsize); + } } } template typename full_vector, template typename half_vector> X86_SIMD_SORT_INLINE void -xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { using keytype = typename std::conditional) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { @@ -526,63 +537,68 @@ xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan UNUSED(hasnan); kvselect_(keys, indexes, k, 0, arrsize - 1, maxiters); } + + if (descending) { + std::reverse(keys, keys + arrsize); + std::reverse(indexes, indexes + arrsize); + } } } template typename full_vector, template typename half_vector> X86_SIMD_SORT_INLINE void -xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) { if (k == 0) return; - xss_select_kv(keys, indexes, k - 1, arrsize, hasnan); - xss_qsort_kv(keys, indexes, k - 1, hasnan); + xss_select_kv(keys, indexes, k - 1, arrsize, hasnan, descending); + xss_qsort_kv(keys, indexes, k - 1, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_qsort_kv( - keys, indexes, arrsize, hasnan); + keys, indexes, arrsize, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false) +avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_qsort_kv( - keys, indexes, arrsize, hasnan); + keys, indexes, arrsize, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx512_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +avx512_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_select_kv( - keys, indexes, k, arrsize, hasnan); + keys, indexes, k, arrsize, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx2_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +avx2_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_select_kv( - keys, indexes, k, arrsize, hasnan); + keys, indexes, k, arrsize, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx512_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +avx512_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_partial_sort_kv( - keys, indexes, k, arrsize, hasnan); + keys, indexes, k, arrsize, hasnan, descending); } template X86_SIMD_SORT_INLINE void -avx2_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +avx2_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) { xss_partial_sort_kv( - keys, indexes, k, arrsize, hasnan); + keys, indexes, k, arrsize, hasnan, descending); } #endif // AVX512_QSORT_64BIT_KV diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 2d5b4ea1..64011941 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -672,6 +672,7 @@ template X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { + if (k == 0) return; xss_qselect(arr, k - 1, arrsize, hasnan); xss_qsort(arr, k - 1, hasnan); } diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index 51e17591..be2e43d6 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -81,7 +81,7 @@ bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, siz return true; } -TYPED_TEST_P(simdkvsort, test_kvsort) +TYPED_TEST_P(simdkvsort, test_kvsort_ascending) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; @@ -92,9 +92,9 @@ TYPED_TEST_P(simdkvsort, test_kvsort) std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan); + x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan, false); xss::scalar::keyvalue_qsort( - key_bckp.data(), val_bckp.data(), size, hasnan); + key_bckp.data(), val_bckp.data(), size, hasnan, false); bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); ASSERT_EQ(is_kv_equivalent, true); @@ -107,7 +107,33 @@ TYPED_TEST_P(simdkvsort, test_kvsort) } } -TYPED_TEST_P(simdkvsort, test_kvselect) +TYPED_TEST_P(simdkvsort, test_kvsort_descending) +{ + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + std::vector key = get_array(type, size); + std::vector val = get_array(type, size); + std::vector key_bckp = key; + std::vector val_bckp = val; + x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan, true); + xss::scalar::keyvalue_qsort( + key_bckp.data(), val_bckp.data(), size, hasnan, true); + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + ASSERT_EQ(is_kv_equivalent, true); + + key.clear(); + val.clear(); + key_bckp.clear(); + val_bckp.clear(); + } + } +} + +TYPED_TEST_P(simdkvsort, test_kvselect_ascending) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; @@ -122,13 +148,77 @@ TYPED_TEST_P(simdkvsort, test_kvselect) std::vector val_bckp = val; xss::scalar::keyvalue_qsort( - key_bckp.data(), val_bckp.data(), size, hasnan); + key_bckp.data(), val_bckp.data(), size, hasnan, false); // Test select by using it as part of partial_sort - x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan); + x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan, false); IS_ARR_PARTITIONED(key, k, key_bckp[k], type); - xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan); + xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, false); + + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); + ASSERT_EQ(is_kv_equivalent, true); + key.clear(); + val.clear(); + key_bckp.clear(); + val_bckp.clear(); + } + } +} + +TYPED_TEST_P(simdkvsort, test_kvselect_descending) +{ + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + + std::vector key = get_array(type, size); + std::vector val = get_array(type, size); + std::vector key_bckp = key; + std::vector val_bckp = val; + + xss::scalar::keyvalue_qsort( + key_bckp.data(), val_bckp.data(), size, hasnan, true); + + // Test select by using it as part of partial_sort + x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan, true); + IS_ARR_PARTITIONED(key, k, key_bckp[k], type, true); + xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, true); + + + bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); + ASSERT_EQ(is_kv_equivalent, true); + + key.clear(); + val.clear(); + key_bckp.clear(); + val_bckp.clear(); + } + } +} + +TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending) +{ + using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; + using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; + for (auto type : this->arrtype) { + bool hasnan = (type == "rand_with_nan") ? true : false; + for (auto size : this->arrsize) { + size_t k = rand() % size; + + std::vector key = get_array(type, size); + std::vector val = get_array(type, size); + std::vector key_bckp = key; + std::vector val_bckp = val; + x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan, false); + xss::scalar::keyvalue_qsort( + key_bckp.data(), val_bckp.data(), size, hasnan, false); + + IS_ARR_PARTIALSORTED(key, k, key_bckp, type); bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); ASSERT_EQ(is_kv_equivalent, true); @@ -141,7 +231,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect) } } -TYPED_TEST_P(simdkvsort, test_kvpartial_sort) +TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending) { using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; @@ -154,9 +244,9 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort) std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan); + x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan, true); xss::scalar::keyvalue_qsort( - key_bckp.data(), val_bckp.data(), size, hasnan); + key_bckp.data(), val_bckp.data(), size, hasnan, true); IS_ARR_PARTIALSORTED(key, k, key_bckp, type); @@ -217,7 +307,7 @@ TYPED_TEST_P(simdkvsort, test_validator) ASSERT_EQ(is_kv_equivalent, true); } -REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort, test_kvselect, test_kvpartial_sort, test_validator); +REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort_ascending, test_kvsort_descending, test_kvselect_ascending, test_kvselect_descending, test_kvpartial_sort_ascending, test_kvpartial_sort_descending, test_validator); #define CREATE_TUPLES(type) \ std::tuple, std::tuple, \ From 9c1b20f70b85e8430b3ce4240336601bc015997b Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 7 May 2024 15:18:30 -0700 Subject: [PATCH 5/8] Fixed testing logic --- lib/x86simdsort-internal.h | 14 ++--- tests/test-keyvalue.cpp | 103 +++++++++++++++++++++++++++++-------- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 8ef6814c..8ef4066a 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -23,7 +23,7 @@ namespace avx512 { bool descending = false); // key-value select template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template @@ -34,7 +34,7 @@ namespace avx512 { bool descending = false); // key-value partial sort template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template @@ -55,7 +55,7 @@ namespace avx2 { // key-value quicksort template XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false); + keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -65,7 +65,7 @@ namespace avx2 { bool descending = false); // key-value select template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template @@ -76,7 +76,7 @@ namespace avx2 { bool descending = false); // key-value partial sort template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template @@ -107,7 +107,7 @@ namespace scalar { bool descending = false); // key-value select template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // partial sort template @@ -118,7 +118,7 @@ namespace scalar { bool descending = false); // key-value partial sort template - XSS_EXPORT_SYMBOL void + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); // argsort template diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index be2e43d6..919f9134 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -32,10 +32,9 @@ TYPED_TEST_SUITE_P(simdkvsort); template bool same_values(T* v1, T* v2, size_t size){ - // Checks that the values are the same except (maybe) their ordering + // Checks that the values are the same except ordering auto cmp_eq = compare>(); - // TODO hardcoding hasnan to true doesn't break anything right? x86simdsort::qsort(v1, size, true); x86simdsort::qsort(v2, size, true); @@ -49,7 +48,7 @@ bool same_values(T* v1, T* v2, size_t size){ } template -bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){ +bool is_kv_sorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){ auto cmp_eq = compare>(); // First check keys are exactly identical @@ -66,7 +65,7 @@ bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, siz size_t i = 0; for (; i < size; i++){ if (!cmp_eq(keys_comp[i], key_start)){ - // Check that every value in + // Check that every value in this block of constant keys if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ return false; @@ -78,6 +77,66 @@ bool kv_equivalent(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, siz } } + // Handle the last group + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + return false; + } + + return true; +} + +template +bool is_kv_partialsorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size, size_t k){ + auto cmp_eq = compare>(); + + // First check keys are exactly identical (up to k) + for (size_t i = 0; i < k; i++){ + if (!cmp_eq(keys_comp[i], keys_ref[i])){ + return false; + } + } + + size_t i_start = 0; + T1 key_start = keys_comp[0]; + // Loop through all identical keys in a block, then compare the sets of values to make sure they are identical + for (size_t i = 0; i < k; i++){ + if (!cmp_eq(keys_comp[i], key_start)){ + // Check that every value in this block of constant keys + + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + return false; + } + + // Now setup the start variables to begin gathering keys for the next group + i_start = i; + key_start = keys_comp[i]; + } + } + + // Now, we need to do some more work to handle keys exactly equal to the true kth + // First, fully kvsort both arrays + xss::scalar::keyvalue_qsort(keys_ref, vals_ref, size, true, false); + xss::scalar::keyvalue_qsort(keys_comp, vals_comp, size, true, false); + + auto trueKth = keys_ref[k]; + bool notFoundFirst = true; + size_t i = 0; + + for (; i < size; i++){ + if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)){ + notFoundFirst = false; + i_start = i; + }else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)){ + break; + } + } + + if (notFoundFirst) return false; + + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + return false; + } + return true; } @@ -96,8 +155,8 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending) xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, false); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_sorted_ = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + ASSERT_EQ(is_kv_sorted_, true); key.clear(); val.clear(); @@ -122,8 +181,8 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending) xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, true); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_sorted_ = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + ASSERT_EQ(is_kv_sorted_, true); key.clear(); val.clear(); @@ -155,9 +214,10 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending) IS_ARR_PARTITIONED(key, k, key_bckp[k], type); xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, false); + ASSERT_EQ(key[k], key_bckp[k]); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + ASSERT_EQ(is_kv_partialsorted_, true); key.clear(); val.clear(); @@ -189,9 +249,10 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending) IS_ARR_PARTITIONED(key, k, key_bckp[k], type, true); xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, true); + ASSERT_EQ(key[k], key_bckp[k]); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + ASSERT_EQ(is_kv_partialsorted_, true); key.clear(); val.clear(); @@ -220,8 +281,8 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending) IS_ARR_PARTIALSORTED(key, k, key_bckp, type); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + ASSERT_EQ(is_kv_partialsorted_, true); key.clear(); val.clear(); @@ -250,8 +311,8 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending) IS_ARR_PARTIALSORTED(key, k, key_bckp, type); - bool is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), k); - ASSERT_EQ(is_kv_equivalent, true); + bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + ASSERT_EQ(is_kv_partialsorted_, true); key.clear(); val.clear(); @@ -275,26 +336,26 @@ TYPED_TEST_P(simdkvsort, test_validator) std::vector val_bckp = val; // Duplicate keys, but otherwise exactly identical - is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); ASSERT_EQ(is_kv_equivalent, true); val = {2,1,4,3}; // Now values are backwards, but this is still fine - is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); ASSERT_EQ(is_kv_equivalent, true); val = {1,3,2,4}; // Now values are mixed up, should fail - is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); ASSERT_EQ(is_kv_equivalent, false); val = {1,2,3,4}; key = {0,0,0,0}; // Now keys are messed up, should fail - is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); ASSERT_EQ(is_kv_equivalent, false); key = {0,0,0,0,0,0}; @@ -303,7 +364,7 @@ TYPED_TEST_P(simdkvsort, test_validator) val = {4,3,1,6,5,2}; // All keys identical, simply reordered values - is_kv_equivalent = kv_equivalent(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); ASSERT_EQ(is_kv_equivalent, true); } From 56623b16e1865f96d16374f8da371fb26655b5ac Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Tue, 7 May 2024 16:34:52 -0700 Subject: [PATCH 6/8] README updates and formatting --- README.md | 4 +- benchmarks/bench-keyvalue.hpp | 3 +- lib/x86simdsort-avx2.cpp | 31 +++- lib/x86simdsort-internal.h | 69 ++++++-- lib/x86simdsort-scalar.h | 21 ++- lib/x86simdsort-skx.cpp | 31 +++- lib/x86simdsort.cpp | 22 ++- lib/x86simdsort.h | 23 ++- src/x86simdsort-static-incl.h | 37 ++++- src/xss-common-keyvaluesort.hpp | 116 ++++++++----- tests/test-keyvalue.cpp | 284 ++++++++++++++++++++------------ 11 files changed, 442 insertions(+), 199 deletions(-) diff --git a/README.md b/README.md index 308edfde..99e8431f 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,9 @@ int32_t, double, uint64_t, int64_t]` ## Key-value sort routines on pairs of arrays ```cpp -void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan); +void x86simdsort::keyvalue_qsort(T1* key, T2* val, size_t size, bool hasnan, bool descending); +void x86simdsort::keyvalue_select(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending); +void x86simdsort::keyvalue_partial_sort(T1* key, T2* val, size_t k, size_t size, bool hasnan, bool descending); ``` Supported datatypes: `T1`, `T2` $\in$ `[float, uint32_t, int32_t, double, uint64_t, int64_t]` Note that keyvalue sort is not yet supported for 16-bit diff --git a/benchmarks/bench-keyvalue.hpp b/benchmarks/bench-keyvalue.hpp index 5ed8d48a..e021bdf5 100644 --- a/benchmarks/bench-keyvalue.hpp +++ b/benchmarks/bench-keyvalue.hpp @@ -13,7 +13,8 @@ static void scalarkvsort(benchmark::State &state, Args &&...args) std::vector key_bkp = key; // benchmark for (auto _ : state) { - xss::scalar::keyvalue_qsort(key.data(), val.data(), arrsize, false, false); + xss::scalar::keyvalue_qsort( + key.data(), val.data(), arrsize, false, false); state.PauseTiming(); key = key_bkp; state.ResumeTiming(); diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp index 9bbef8cd..c00591e4 100644 --- a/lib/x86simdsort-avx2.cpp +++ b/lib/x86simdsort-avx2.cpp @@ -36,21 +36,38 @@ #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_qsort(type1 *key, \ + type2 *val, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_qsort( \ + key, val, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_select(type1 *key, \ + type2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_select( \ + key, val, k, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_partial_sort(type1 *key, \ + type2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_partial_sort( \ + key, val, k, arrsize, hasnan, descending); \ } - + #define DEFINE_KEYVALUE_METHODS(type) \ DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \ DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \ diff --git a/lib/x86simdsort-internal.h b/lib/x86simdsort-internal.h index 8ef4066a..6cf261a9 100644 --- a/lib/x86simdsort-internal.h +++ b/lib/x86simdsort-internal.h @@ -12,8 +12,11 @@ namespace avx512 { qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template - XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, + T2 *val, + size_t arrsize, + bool hasnan = false, + bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -23,8 +26,12 @@ namespace avx512 { bool descending = false); // key-value select template - XSS_HIDE_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -34,8 +41,12 @@ namespace avx512 { bool descending = false); // key-value partial sort template - XSS_HIDE_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -54,8 +65,11 @@ namespace avx2 { qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template - XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, + T2 *val, + size_t arrsize, + bool hasnan = false, + bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -65,8 +79,12 @@ namespace avx2 { bool descending = false); // key-value select template - XSS_HIDE_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -76,8 +94,12 @@ namespace avx2 { bool descending = false); // key-value partial sort template - XSS_HIDE_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, @@ -96,8 +118,11 @@ namespace scalar { qsort(T *arr, size_t arrsize, bool hasnan = false, bool descending = false); // key-value quicksort template - XSS_HIDE_SYMBOL void - keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_qsort(T1 *key, + T2 *val, + size_t arrsize, + bool hasnan = false, + bool descending = false); // quickselect template XSS_HIDE_SYMBOL void qselect(T *arr, @@ -107,8 +132,12 @@ namespace scalar { bool descending = false); // key-value select template - XSS_HIDE_SYMBOL void - keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // partial sort template XSS_HIDE_SYMBOL void partial_qsort(T *arr, @@ -118,8 +147,12 @@ namespace scalar { bool descending = false); // key-value partial sort template - XSS_HIDE_SYMBOL void - keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); + XSS_HIDE_SYMBOL void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // argsort template XSS_HIDE_SYMBOL std::vector argsort(T *arr, diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index ce383f5a..dadd1eef 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -100,30 +100,41 @@ namespace scalar { return arg; } template - void keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) + void keyvalue_qsort( + T1 *key, T2 *val, size_t arrsize, bool hasnan, bool descending) { std::vector arg = argsort(key, arrsize, hasnan, descending); utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } template - void keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) + void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { if (k == 0) return; // Note that this does a full partial sort, not just a select std::vector arg = argsort(key, arrsize, hasnan, descending); //arg.resize(k); - + utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } template - void keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) + void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan, + bool descending) { if (k == 0) return; std::vector arg = argsort(key, arrsize, hasnan, descending); //arg.resize(k); - + utils::apply_permutation_in_place(key, arg); utils::apply_permutation_in_place(val, arg); } diff --git a/lib/x86simdsort-skx.cpp b/lib/x86simdsort-skx.cpp index 2917ae59..7d9d5aa4 100644 --- a/lib/x86simdsort-skx.cpp +++ b/lib/x86simdsort-skx.cpp @@ -36,21 +36,38 @@ #define DEFINE_KEYVALUE_METHODS_BASE(type1, type2) \ template <> \ - void keyvalue_qsort(type1 *key, type2 *val, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_qsort(type1 *key, \ + type2 *val, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_qsort(key, val, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_qsort( \ + key, val, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_select(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_select(type1 *key, \ + type2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_select(key, val, k, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_select( \ + key, val, k, arrsize, hasnan, descending); \ } \ template <> \ - void keyvalue_partial_sort(type1 *key, type2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_partial_sort(type1 *key, \ + type2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ - x86simdsortStatic::keyvalue_partial_sort(key, val, k, arrsize, hasnan, descending); \ + x86simdsortStatic::keyvalue_partial_sort( \ + key, val, k, arrsize, hasnan, descending); \ } - + #define DEFINE_KEYVALUE_METHODS(type) \ DEFINE_KEYVALUE_METHODS_BASE(type, uint64_t) \ DEFINE_KEYVALUE_METHODS_BASE(type, int64_t) \ diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 64885978..e01a86f1 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -134,7 +134,11 @@ namespace x86simdsort { TYPE1 *, TYPE2 *, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_qsort(TYPE1 *key, TYPE2 *val, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_qsort(TYPE1 *key, \ + TYPE2 *val, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ (CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \ key, val, arrsize, hasnan, descending); \ @@ -160,12 +164,17 @@ namespace x86simdsort { return; \ } \ } \ - }\ + } \ static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_select(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_select(TYPE1 *key, \ + TYPE2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ (CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ key, val, k, arrsize, hasnan, descending); \ @@ -196,7 +205,12 @@ namespace x86simdsort { TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ = NULL; \ template <> \ - void keyvalue_partial_sort(TYPE1 *key, TYPE2 *val, size_t k, size_t arrsize, bool hasnan, bool descending) \ + void keyvalue_partial_sort(TYPE1 *key, \ + TYPE2 *val, \ + size_t k, \ + size_t arrsize, \ + bool hasnan, \ + bool descending) \ { \ (CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ key, val, k, arrsize, hasnan, descending); \ diff --git a/lib/x86simdsort.h b/lib/x86simdsort.h index 8d924e97..c79f2648 100644 --- a/lib/x86simdsort.h +++ b/lib/x86simdsort.h @@ -45,18 +45,29 @@ argselect(T *arr, size_t k, size_t arrsize, bool hasnan = false); // keyvalue sort template -XSS_EXPORT_SYMBOL void -keyvalue_qsort(T1 *key, T2 *val, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void keyvalue_qsort(T1 *key, + T2 *val, + size_t arrsize, + bool hasnan = false, + bool descending = false); // keyvalue select template -XSS_EXPORT_SYMBOL void -keyvalue_select(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // keyvalue partial sort template -XSS_EXPORT_SYMBOL void -keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t arrsize, bool hasnan = false, bool descending = false); +XSS_EXPORT_SYMBOL void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t arrsize, + bool hasnan = false, + bool descending = false); // sort an object template diff --git a/src/x86simdsort-static-incl.h b/src/x86simdsort-static-incl.h index 21c29a67..52dde7b3 100644 --- a/src/x86simdsort-static-incl.h +++ b/src/x86simdsort-static-incl.h @@ -46,16 +46,27 @@ void X86_SIMD_SORT_FINLINE argselect(T *arr, size_t *arg, size_t k, size_t size, bool hasnan = false); template -X86_SIMD_SORT_FINLINE void -keyvalue_qsort(T1 *key, T2 *val, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE void keyvalue_qsort(T1 *key, + T2 *val, + size_t size, + bool hasnan = false, + bool descending = false); template -X86_SIMD_SORT_FINLINE void -keyvalue_select(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE void keyvalue_select(T1 *key, + T2 *val, + size_t k, + size_t size, + bool hasnan = false, + bool descending = false); template -X86_SIMD_SORT_FINLINE void -keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = false, bool descending = false); +X86_SIMD_SORT_FINLINE void keyvalue_partial_sort(T1 *key, + T2 *val, + size_t k, + size_t size, + bool hasnan = false, + bool descending = false); } // namespace x86simdsortStatic @@ -117,13 +128,23 @@ keyvalue_partial_sort(T1 *key, T2 *val, size_t k, size_t size, bool hasnan = fal } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_select( \ - T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \ + T1 *key, \ + T2 *val, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending) \ { \ ISA##_select_kv(key, val, k, size, hasnan, descending); \ } \ template \ X86_SIMD_SORT_FINLINE void x86simdsortStatic::keyvalue_partial_sort( \ - T1 *key, T2 *val, size_t k, size_t size, bool hasnan, bool descending) \ + T1 *key, \ + T2 *val, \ + size_t k, \ + size_t size, \ + bool hasnan, \ + bool descending) \ { \ ISA##_partial_sort_kv(key, val, k, size, hasnan, descending); \ } diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 2ec5614a..2615aad8 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -406,11 +406,11 @@ template X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, - type2_t *indexes, - arrsize_t pos, - arrsize_t left, - arrsize_t right, - int max_iters) + type2_t *indexes, + arrsize_t pos, + arrsize_t left, + arrsize_t right, + int max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -435,7 +435,7 @@ X86_SIMD_SORT_INLINE void kvselect_(type1_t *keys, type1_t biggest = vtype1::type_min(); arrsize_t pivot_index = kvpartition_unrolled( keys, indexes, left, right + 1, pivot, &smallest, &biggest); - + if ((pivot != smallest) && (pos < pivot_index)) { kvselect_( keys, indexes, pos, left, pivot_index - 1, max_iters - 1); @@ -452,8 +452,8 @@ template typename half_vector> -X86_SIMD_SORT_INLINE void -xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descending) +X86_SIMD_SORT_INLINE void xss_qsort_kv( + T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descending) { using keytype = typename std::conditional(keys, indexes, 0, arrsize - 1, maxiters); replace_inf_with_nan(keys, arrsize, nan_count); - + if (descending) { std::reverse(keys, keys + arrsize); std::reverse(indexes, indexes + arrsize); @@ -496,9 +496,18 @@ xss_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan, bool descend } } -template typename full_vector, template typename half_vector> -X86_SIMD_SORT_INLINE void -xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +template + typename full_vector, + template + typename half_vector> +X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool descending) { using keytype = typename std::conditional) { arrsize_t nan_count = 0; if (UNLIKELY(hasnan)) { nan_count = replace_nan_with_inf>(keys, arrsize); } - kvselect_(keys, indexes, k, 0, arrsize - 1, maxiters); + kvselect_( + keys, indexes, k, 0, arrsize - 1, maxiters); replace_inf_with_nan(keys, arrsize, nan_count); } else { UNUSED(hasnan); - kvselect_(keys, indexes, k, 0, arrsize - 1, maxiters); + kvselect_( + keys, indexes, k, 0, arrsize - 1, maxiters); } - + if (descending) { std::reverse(keys, keys + arrsize); std::reverse(indexes, indexes + arrsize); @@ -545,58 +554,91 @@ xss_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan } } -template typename full_vector, template typename half_vector> -X86_SIMD_SORT_INLINE void -xss_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan, bool descending) +template + typename full_vector, + template + typename half_vector> +X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan, + bool descending) { if (k == 0) return; - xss_select_kv(keys, indexes, k - 1, arrsize, hasnan, descending); - xss_qsort_kv(keys, indexes, k - 1, hasnan, descending); + xss_select_kv( + keys, indexes, k - 1, arrsize, hasnan, descending); + xss_qsort_kv( + keys, indexes, k - 1, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx512_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx512_qsort_kv(T1 *keys, + T2 *indexes, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_qsort_kv( keys, indexes, arrsize, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx2_qsort_kv(T1 *keys, T2 *indexes, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx2_qsort_kv(T1 *keys, + T2 *indexes, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_qsort_kv( keys, indexes, arrsize, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx512_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx512_select_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_select_kv( keys, indexes, k, arrsize, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx2_select_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx2_select_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_select_kv( keys, indexes, k, arrsize, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx512_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx512_partial_sort_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_partial_sort_kv( keys, indexes, k, arrsize, hasnan, descending); } template -X86_SIMD_SORT_INLINE void -avx2_partial_sort_kv(T1 *keys, T2 *indexes, arrsize_t k, arrsize_t arrsize, bool hasnan = false, bool descending = false) +X86_SIMD_SORT_INLINE void avx2_partial_sort_kv(T1 *keys, + T2 *indexes, + arrsize_t k, + arrsize_t arrsize, + bool hasnan = false, + bool descending = false) { xss_partial_sort_kv( keys, indexes, k, arrsize, hasnan, descending); diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index 919f9134..c6ad960e 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -31,112 +31,119 @@ class simdkvsort : public ::testing::Test { TYPED_TEST_SUITE_P(simdkvsort); template -bool same_values(T* v1, T* v2, size_t size){ +bool same_values(T *v1, T *v2, size_t size) +{ // Checks that the values are the same except ordering auto cmp_eq = compare>(); - + x86simdsort::qsort(v1, size, true); x86simdsort::qsort(v2, size, true); - - for (size_t i = 0; i < size; i++){ - if (!cmp_eq(v1[i], v2[i])){ - return false; - } + + for (size_t i = 0; i < size; i++) { + if (!cmp_eq(v1[i], v2[i])) { return false; } } - + return true; } template -bool is_kv_sorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size){ +bool is_kv_sorted( + T1 *keys_comp, T2 *vals_comp, T1 *keys_ref, T2 *vals_ref, size_t size) +{ auto cmp_eq = compare>(); - + // First check keys are exactly identical - for (size_t i = 0; i < size; i++){ - if (!cmp_eq(keys_comp[i], keys_ref[i])){ - return false; - } + for (size_t i = 0; i < size; i++) { + if (!cmp_eq(keys_comp[i], keys_ref[i])) { return false; } } - + size_t i_start = 0; T1 key_start = keys_comp[0]; // Loop through all identical keys in a block, then compare the sets of values to make sure they are identical // We need the index after the loop size_t i = 0; - for (; i < size; i++){ - if (!cmp_eq(keys_comp[i], key_start)){ + for (; i < size; i++) { + if (!cmp_eq(keys_comp[i], key_start)) { // Check that every value in this block of constant keys - if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + if (!same_values( + vals_ref + i_start, vals_comp + i_start, i - i_start)) { return false; } - + // Now setup the start variables to begin gathering keys for the next group i_start = i; key_start = keys_comp[i]; } } - + // Handle the last group - if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) { return false; } - + return true; } template -bool is_kv_partialsorted(T1* keys_comp, T2* vals_comp, T1* keys_ref, T2* vals_ref, size_t size, size_t k){ +bool is_kv_partialsorted(T1 *keys_comp, + T2 *vals_comp, + T1 *keys_ref, + T2 *vals_ref, + size_t size, + size_t k) +{ auto cmp_eq = compare>(); - + // First check keys are exactly identical (up to k) - for (size_t i = 0; i < k; i++){ - if (!cmp_eq(keys_comp[i], keys_ref[i])){ - return false; - } + for (size_t i = 0; i < k; i++) { + if (!cmp_eq(keys_comp[i], keys_ref[i])) { return false; } } - + size_t i_start = 0; T1 key_start = keys_comp[0]; // Loop through all identical keys in a block, then compare the sets of values to make sure they are identical - for (size_t i = 0; i < k; i++){ - if (!cmp_eq(keys_comp[i], key_start)){ + for (size_t i = 0; i < k; i++) { + if (!cmp_eq(keys_comp[i], key_start)) { // Check that every value in this block of constant keys - if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + if (!same_values( + vals_ref + i_start, vals_comp + i_start, i - i_start)) { return false; } - + // Now setup the start variables to begin gathering keys for the next group i_start = i; key_start = keys_comp[i]; } } - + // Now, we need to do some more work to handle keys exactly equal to the true kth // First, fully kvsort both arrays xss::scalar::keyvalue_qsort(keys_ref, vals_ref, size, true, false); - xss::scalar::keyvalue_qsort(keys_comp, vals_comp, size, true, false); - + xss::scalar::keyvalue_qsort( + keys_comp, vals_comp, size, true, false); + auto trueKth = keys_ref[k]; bool notFoundFirst = true; size_t i = 0; - - for (; i < size; i++){ - if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)){ + + for (; i < size; i++) { + if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)) { notFoundFirst = false; i_start = i; - }else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)){ + } + else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)) { break; } } if (notFoundFirst) return false; - - if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)){ + + if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) { return false; } - + return true; } @@ -151,13 +158,18 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending) std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan, false); + x86simdsort::keyvalue_qsort( + key.data(), val.data(), size, hasnan, false); xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, false); - - bool is_kv_sorted_ = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + + bool is_kv_sorted_ = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size); ASSERT_EQ(is_kv_sorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -177,13 +189,18 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending) std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_qsort(key.data(), val.data(), size, hasnan, true); + x86simdsort::keyvalue_qsort( + key.data(), val.data(), size, hasnan, true); xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, true); - - bool is_kv_sorted_ = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size); + + bool is_kv_sorted_ = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size); ASSERT_EQ(is_kv_sorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -200,25 +217,33 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - + std::vector key = get_array(type, size); std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - + xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, false); - + // Test select by using it as part of partial_sort - x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan, false); + x86simdsort::keyvalue_select( + key.data(), val.data(), k, size, hasnan, false); IS_ARR_PARTITIONED(key, k, key_bckp[k], type); - xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, false); - + xss::scalar::keyvalue_qsort( + key.data(), val.data(), k, hasnan, false); + ASSERT_EQ(key[k], key_bckp[k]); - - bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + + bool is_kv_partialsorted_ + = is_kv_partialsorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size, + k); ASSERT_EQ(is_kv_partialsorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -235,25 +260,33 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - + std::vector key = get_array(type, size); std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - + xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, true); - + // Test select by using it as part of partial_sort - x86simdsort::keyvalue_select(key.data(), val.data(), k, size, hasnan, true); + x86simdsort::keyvalue_select( + key.data(), val.data(), k, size, hasnan, true); IS_ARR_PARTITIONED(key, k, key_bckp[k], type, true); - xss::scalar::keyvalue_qsort(key.data(), val.data(), k, hasnan, true); - + xss::scalar::keyvalue_qsort( + key.data(), val.data(), k, hasnan, true); + ASSERT_EQ(key[k], key_bckp[k]); - - bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + + bool is_kv_partialsorted_ + = is_kv_partialsorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size, + k); ASSERT_EQ(is_kv_partialsorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -270,20 +303,27 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - + std::vector key = get_array(type, size); std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan, false); + x86simdsort::keyvalue_partial_sort( + key.data(), val.data(), k, size, hasnan, false); xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, false); - + IS_ARR_PARTIALSORTED(key, k, key_bckp, type); - - bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + + bool is_kv_partialsorted_ + = is_kv_partialsorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size, + k); ASSERT_EQ(is_kv_partialsorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -300,20 +340,27 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending) bool hasnan = (type == "rand_with_nan") ? true : false; for (auto size : this->arrsize) { size_t k = rand() % size; - + std::vector key = get_array(type, size); std::vector val = get_array(type, size); std::vector key_bckp = key; std::vector val_bckp = val; - x86simdsort::keyvalue_partial_sort(key.data(), val.data(), k, size, hasnan, true); + x86simdsort::keyvalue_partial_sort( + key.data(), val.data(), k, size, hasnan, true); xss::scalar::keyvalue_qsort( key_bckp.data(), val_bckp.data(), size, hasnan, true); - + IS_ARR_PARTIALSORTED(key, k, key_bckp, type); - - bool is_kv_partialsorted_ = is_kv_partialsorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), size, k); + + bool is_kv_partialsorted_ + = is_kv_partialsorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + size, + k); ASSERT_EQ(is_kv_partialsorted_, true); - + key.clear(); val.clear(); key_bckp.clear(); @@ -327,48 +374,75 @@ TYPED_TEST_P(simdkvsort, test_validator) // Tests a few edge cases to verify the tests are working correctly and identifying it as functional using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; - + bool is_kv_equivalent; - + std::vector key = {0, 0, 1, 1}; std::vector val = {1, 2, 3, 4}; std::vector key_bckp = key; std::vector val_bckp = val; - + // Duplicate keys, but otherwise exactly identical - is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + key.size()); ASSERT_EQ(is_kv_equivalent, true); - - val = {2,1,4,3}; - + + val = {2, 1, 4, 3}; + // Now values are backwards, but this is still fine - is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + key.size()); ASSERT_EQ(is_kv_equivalent, true); - - val = {1,3,2,4}; - + + val = {1, 3, 2, 4}; + // Now values are mixed up, should fail - is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + key.size()); ASSERT_EQ(is_kv_equivalent, false); - - val = {1,2,3,4}; - key = {0,0,0,0}; - + + val = {1, 2, 3, 4}; + key = {0, 0, 0, 0}; + // Now keys are messed up, should fail - is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + key.size()); ASSERT_EQ(is_kv_equivalent, false); - - key = {0,0,0,0,0,0}; + + key = {0, 0, 0, 0, 0, 0}; key_bckp = key; - val_bckp = {1,2,3,4,5,6}; - val = {4,3,1,6,5,2}; - + val_bckp = {1, 2, 3, 4, 5, 6}; + val = {4, 3, 1, 6, 5, 2}; + // All keys identical, simply reordered values - is_kv_equivalent = is_kv_sorted(key.data(), val.data(), key_bckp.data(), val_bckp.data(), key.size()); + is_kv_equivalent = is_kv_sorted(key.data(), + val.data(), + key_bckp.data(), + val_bckp.data(), + key.size()); ASSERT_EQ(is_kv_equivalent, true); } -REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort_ascending, test_kvsort_descending, test_kvselect_ascending, test_kvselect_descending, test_kvpartial_sort_ascending, test_kvpartial_sort_descending, test_validator); +REGISTER_TYPED_TEST_SUITE_P(simdkvsort, + test_kvsort_ascending, + test_kvsort_descending, + test_kvselect_ascending, + test_kvselect_descending, + test_kvpartial_sort_ascending, + test_kvpartial_sort_descending, + test_validator); #define CREATE_TUPLES(type) \ std::tuple, std::tuple, \ From f57e0ac8a5ff3ca15ae50fd69962c9c66d84817e Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 8 May 2024 11:57:01 -0700 Subject: [PATCH 7/8] Simplified scalar logic for kv-select and key-value partial-sort --- lib/x86simdsort-scalar.h | 19 ++++------- tests/test-keyvalue.cpp | 69 +--------------------------------------- 2 files changed, 7 insertions(+), 81 deletions(-) diff --git a/lib/x86simdsort-scalar.h b/lib/x86simdsort-scalar.h index dadd1eef..3dc737ca 100644 --- a/lib/x86simdsort-scalar.h +++ b/lib/x86simdsort-scalar.h @@ -115,13 +115,9 @@ namespace scalar { bool hasnan, bool descending) { - if (k == 0) return; - // Note that this does a full partial sort, not just a select - std::vector arg = argsort(key, arrsize, hasnan, descending); - //arg.resize(k); - - utils::apply_permutation_in_place(key, arg); - utils::apply_permutation_in_place(val, arg); + // Note that this does a full kv-sort + UNUSED(k); + keyvalue_qsort(key, val, arrsize, hasnan, descending); } template void keyvalue_partial_sort(T1 *key, @@ -131,12 +127,9 @@ namespace scalar { bool hasnan, bool descending) { - if (k == 0) return; - std::vector arg = argsort(key, arrsize, hasnan, descending); - //arg.resize(k); - - utils::apply_permutation_in_place(key, arg); - utils::apply_permutation_in_place(val, arg); + // Note that this does a full kv-sort + UNUSED(k); + keyvalue_qsort(key, val, arrsize, hasnan, descending); } } // namespace scalar diff --git a/tests/test-keyvalue.cpp b/tests/test-keyvalue.cpp index c6ad960e..d3a796f1 100644 --- a/tests/test-keyvalue.cpp +++ b/tests/test-keyvalue.cpp @@ -369,80 +369,13 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending) } } -TYPED_TEST_P(simdkvsort, test_validator) -{ - // Tests a few edge cases to verify the tests are working correctly and identifying it as functional - using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type; - using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type; - - bool is_kv_equivalent; - - std::vector key = {0, 0, 1, 1}; - std::vector val = {1, 2, 3, 4}; - std::vector key_bckp = key; - std::vector val_bckp = val; - - // Duplicate keys, but otherwise exactly identical - is_kv_equivalent = is_kv_sorted(key.data(), - val.data(), - key_bckp.data(), - val_bckp.data(), - key.size()); - ASSERT_EQ(is_kv_equivalent, true); - - val = {2, 1, 4, 3}; - - // Now values are backwards, but this is still fine - is_kv_equivalent = is_kv_sorted(key.data(), - val.data(), - key_bckp.data(), - val_bckp.data(), - key.size()); - ASSERT_EQ(is_kv_equivalent, true); - - val = {1, 3, 2, 4}; - - // Now values are mixed up, should fail - is_kv_equivalent = is_kv_sorted(key.data(), - val.data(), - key_bckp.data(), - val_bckp.data(), - key.size()); - ASSERT_EQ(is_kv_equivalent, false); - - val = {1, 2, 3, 4}; - key = {0, 0, 0, 0}; - - // Now keys are messed up, should fail - is_kv_equivalent = is_kv_sorted(key.data(), - val.data(), - key_bckp.data(), - val_bckp.data(), - key.size()); - ASSERT_EQ(is_kv_equivalent, false); - - key = {0, 0, 0, 0, 0, 0}; - key_bckp = key; - val_bckp = {1, 2, 3, 4, 5, 6}; - val = {4, 3, 1, 6, 5, 2}; - - // All keys identical, simply reordered values - is_kv_equivalent = is_kv_sorted(key.data(), - val.data(), - key_bckp.data(), - val_bckp.data(), - key.size()); - ASSERT_EQ(is_kv_equivalent, true); -} - REGISTER_TYPED_TEST_SUITE_P(simdkvsort, test_kvsort_ascending, test_kvsort_descending, test_kvselect_ascending, test_kvselect_descending, test_kvpartial_sort_ascending, - test_kvpartial_sort_descending, - test_validator); + test_kvpartial_sort_descending); #define CREATE_TUPLES(type) \ std::tuple, std::tuple, \ From f436aae8fe33a41b5c220d5fe5a51b4745766bf0 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 20 May 2024 14:42:08 -0700 Subject: [PATCH 8/8] Reformat key-value dispatch code --- lib/x86simdsort.cpp | 179 ++++++++++++++++++-------------------------- 1 file changed, 72 insertions(+), 107 deletions(-) diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index e01a86f1..a5bbc578 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -129,10 +129,62 @@ namespace x86simdsort { } \ } -#define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \ - static void(CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \ +#define ISA_LIST(...) \ + std::initializer_list \ + { \ + __VA_ARGS__ \ + } + +#ifdef __FLT16_MAX__ +DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(argsort, _Float16, ISA_LIST("none")) +DISPATCH(argselect, _Float16, ISA_LIST("none")) +#endif + +#define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \ + DISPATCH(func, uint16_t, ISA_16BIT) \ + DISPATCH(func, int16_t, ISA_16BIT) \ + DISPATCH(func, float, ISA_32BIT) \ + DISPATCH(func, int32_t, ISA_32BIT) \ + DISPATCH(func, uint32_t, ISA_32BIT) \ + DISPATCH(func, int64_t, ISA_64BIT) \ + DISPATCH(func, uint64_t, ISA_64BIT) \ + DISPATCH(func, double, ISA_64BIT) + +DISPATCH_ALL(qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) +DISPATCH_ALL(qselect, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) +DISPATCH_ALL(partial_qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) +DISPATCH_ALL(argsort, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) +DISPATCH_ALL(argselect, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx", "avx2"))) + +/* Key-Value methods */ +#define DECLARE_ALL_KEYVALUE_METHODS(TYPE1, TYPE2) \ + static void(CAT(CAT(*internal_keyvalue_qsort_, TYPE1), TYPE2))( \ TYPE1 *, TYPE2 *, size_t, bool, bool) \ = NULL; \ + static void(CAT(CAT(*internal_keyvalue_select_, TYPE1), TYPE2))( \ + TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ + = NULL; \ + static void(CAT(CAT(*internal_keyvalue_partial_sort_, TYPE1), TYPE2))( \ + TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ + = NULL; \ template <> \ void keyvalue_qsort(TYPE1 *key, \ TYPE2 *val, \ @@ -140,34 +192,9 @@ namespace x86simdsort { bool hasnan, \ bool descending) \ { \ - (CAT(CAT(*internal_kv_qsort_, TYPE1), TYPE2))( \ + (CAT(CAT(*internal_keyvalue_qsort_, TYPE1), TYPE2))( \ key, val, arrsize, hasnan, descending); \ } \ - static __attribute__((constructor)) void CAT( \ - CAT(resolve_keyvalue_qsort_, TYPE1), TYPE2)(void) \ - { \ - CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) \ - = &xss::scalar::keyvalue_qsort; \ - __builtin_cpu_init(); \ - std::string_view preferred_cpu = find_preferred_cpu(ISA); \ - if constexpr (dispatch_requested("avx512", ISA)) { \ - if (preferred_cpu.find("avx512") != std::string_view::npos) { \ - CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) \ - = &xss::avx512::keyvalue_qsort; \ - return; \ - } \ - } \ - if constexpr (dispatch_requested("avx2", ISA)) { \ - if (preferred_cpu.find("avx2") != std::string_view::npos) { \ - CAT(CAT(internal_kv_qsort_, TYPE1), TYPE2) \ - = &xss::avx2::keyvalue_qsort; \ - return; \ - } \ - } \ - } \ - static void(CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ - TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ - = NULL; \ template <> \ void keyvalue_select(TYPE1 *key, \ TYPE2 *val, \ @@ -176,34 +203,9 @@ namespace x86simdsort { bool hasnan, \ bool descending) \ { \ - (CAT(CAT(*internal_kv_select_, TYPE1), TYPE2))( \ + (CAT(CAT(*internal_keyvalue_select_, TYPE1), TYPE2))( \ key, val, k, arrsize, hasnan, descending); \ } \ - static __attribute__((constructor)) void CAT( \ - CAT(resolve_keyvalue_select_, TYPE1), TYPE2)(void) \ - { \ - CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ - = &xss::scalar::keyvalue_select; \ - __builtin_cpu_init(); \ - std::string_view preferred_cpu = find_preferred_cpu(ISA); \ - if constexpr (dispatch_requested("avx512", ISA)) { \ - if (preferred_cpu.find("avx512") != std::string_view::npos) { \ - CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ - = &xss::avx512::keyvalue_select; \ - return; \ - } \ - } \ - if constexpr (dispatch_requested("avx2", ISA)) { \ - if (preferred_cpu.find("avx2") != std::string_view::npos) { \ - CAT(CAT(internal_kv_select_, TYPE1), TYPE2) \ - = &xss::avx2::keyvalue_select; \ - return; \ - } \ - } \ - } \ - static void(CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ - TYPE1 *, TYPE2 *, size_t, size_t, bool, bool) \ - = NULL; \ template <> \ void keyvalue_partial_sort(TYPE1 *key, \ TYPE2 *val, \ @@ -212,76 +214,39 @@ namespace x86simdsort { bool hasnan, \ bool descending) \ { \ - (CAT(CAT(*internal_kv_partial_sort_, TYPE1), TYPE2))( \ + (CAT(CAT(*internal_keyvalue_partial_sort_, TYPE1), TYPE2))( \ key, val, k, arrsize, hasnan, descending); \ - } \ + } + +#define DISPATCH_KV_FUNC(func, TYPE1, TYPE2, ISA) \ static __attribute__((constructor)) void CAT( \ - CAT(resolve_keyvalue_partial_sort_, TYPE1), TYPE2)(void) \ + CAT(CAT(CAT(resolve_, func), _), TYPE1), TYPE2)(void) \ { \ - CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ - = &xss::scalar::keyvalue_partial_sort; \ + CAT(CAT(CAT(CAT(internal_, func), _), TYPE1), TYPE2) \ + = &xss::scalar::func; \ __builtin_cpu_init(); \ std::string_view preferred_cpu = find_preferred_cpu(ISA); \ if constexpr (dispatch_requested("avx512", ISA)) { \ if (preferred_cpu.find("avx512") != std::string_view::npos) { \ - CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ - = &xss::avx512::keyvalue_partial_sort; \ + CAT(CAT(CAT(CAT(internal_, func), _), TYPE1), TYPE2) \ + = &xss::avx512::func; \ return; \ } \ } \ if constexpr (dispatch_requested("avx2", ISA)) { \ if (preferred_cpu.find("avx2") != std::string_view::npos) { \ - CAT(CAT(internal_kv_partial_sort_, TYPE1), TYPE2) \ - = &xss::avx2::keyvalue_partial_sort; \ + CAT(CAT(CAT(CAT(internal_, func), _), TYPE1), TYPE2) \ + = &xss::avx2::func; \ return; \ } \ } \ } -#define ISA_LIST(...) \ - std::initializer_list \ - { \ - __VA_ARGS__ \ - } - -#ifdef __FLT16_MAX__ -DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr")) -DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr")) -DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr")) -DISPATCH(argsort, _Float16, ISA_LIST("none")) -DISPATCH(argselect, _Float16, ISA_LIST("none")) -#endif - -#define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \ - DISPATCH(func, uint16_t, ISA_16BIT) \ - DISPATCH(func, int16_t, ISA_16BIT) \ - DISPATCH(func, float, ISA_32BIT) \ - DISPATCH(func, int32_t, ISA_32BIT) \ - DISPATCH(func, uint32_t, ISA_32BIT) \ - DISPATCH(func, int64_t, ISA_64BIT) \ - DISPATCH(func, uint64_t, ISA_64BIT) \ - DISPATCH(func, double, ISA_64BIT) - -DISPATCH_ALL(qsort, - (ISA_LIST("avx512_icl")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) -DISPATCH_ALL(qselect, - (ISA_LIST("avx512_icl")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) -DISPATCH_ALL(partial_qsort, - (ISA_LIST("avx512_icl")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) -DISPATCH_ALL(argsort, - (ISA_LIST("none")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) -DISPATCH_ALL(argselect, - (ISA_LIST("none")), - (ISA_LIST("avx512_skx", "avx2")), - (ISA_LIST("avx512_skx", "avx2"))) +#define DISPATCH_KEYVALUE_SORT(TYPE1, TYPE2, ISA) \ + DECLARE_ALL_KEYVALUE_METHODS(TYPE1, TYPE2) \ + DISPATCH_KV_FUNC(keyvalue_qsort, TYPE1, TYPE2, ISA) \ + DISPATCH_KV_FUNC(keyvalue_select, TYPE1, TYPE2, ISA) \ + DISPATCH_KV_FUNC(keyvalue_partial_sort, TYPE1, TYPE2, ISA) #define DISPATCH_KEYVALUE_SORT_FORTYPE(type) \ DISPATCH_KEYVALUE_SORT(type, uint64_t, (ISA_LIST("avx512_skx", "avx2"))) \