Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New pivot selection algorithm to better handle many special cases #127

Merged
merged 6 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/avx2-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ struct avx2_vector<int32_t> {
{
return _mm256_set1_epi32(type_max());
} // TODO: this should broadcast bits as is?
// Logical NOT of a vector mask (avx2_vector<int32_t>): XORs x with the
// vector produced by seti() from eight -1 arguments, flipping every bit
// that is set in that vector.
static opmask_t knot_opmask(opmask_t x)
{
    const auto every_bit_set = seti(-1, -1, -1, -1, -1, -1, -1, -1);
    const auto inverted = _mm256_xor_si256(x, every_bit_set);
    return inverted;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
auto mask = ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -204,6 +209,9 @@ struct avx2_vector<int32_t> {
{
return v;
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed floats and _mm256_movemask_ps collects one bit per 32-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_ps(_mm256_castsi256_ps(k));
    return lane_bits == 0;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -242,6 +250,11 @@ struct avx2_vector<uint32_t> {
{
return _mm256_set1_epi32(type_max());
}
// Logical NOT of a vector mask (avx2_vector<uint32_t>): XORs x with the
// vector produced by seti() from eight -1 arguments, flipping every bit
// that is set in that vector.
static opmask_t knot_opmask(opmask_t x)
{
    const auto every_bit_set = seti(-1, -1, -1, -1, -1, -1, -1, -1);
    const auto inverted = _mm256_xor_si256(x, every_bit_set);
    return inverted;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
auto mask = ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -349,6 +362,9 @@ struct avx2_vector<uint32_t> {
{
return v;
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed floats and _mm256_movemask_ps collects one bit per 32-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_ps(_mm256_castsi256_ps(k));
    return lane_bits == 0;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -387,7 +403,11 @@ struct avx2_vector<float> {
{
return _mm256_set1_ps(type_max());
}

// Logical NOT of a vector mask (avx2_vector<float>): XORs x with the
// vector produced by seti() from eight -1 arguments, flipping every bit
// that is set in that vector. The mask is handled as an integer vector
// (_mm256_xor_si256), not as packed floats.
static opmask_t knot_opmask(opmask_t x)
{
    const auto every_bit_set = seti(-1, -1, -1, -1, -1, -1, -1, -1);
    const auto inverted = _mm256_xor_si256(x, every_bit_set);
    return inverted;
}
static ymmi_t
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
{
Expand Down Expand Up @@ -514,6 +534,9 @@ struct avx2_vector<float> {
{
return _mm256_castps_si256(v);
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed floats and _mm256_movemask_ps collects one bit per 32-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_ps(_mm256_castsi256_ps(k));
    return lane_bits == 0;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
30 changes: 27 additions & 3 deletions src/avx2-64bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,17 @@ struct avx2_vector<int64_t> {
{
return _mm256_set1_epi64x(type_max());
} // TODO: this should broadcast bits as is?
// Logical NOT of a vector mask (avx2_vector<int64_t>): XORs x with an
// all-ones vector so every bit of the mask is flipped.
//
// The all-ones constant is written as -1 rather than 0xFFFF'FFFF'FFFF'FFFF:
// the hex literal has an unsigned type and would be implicitly narrowed to
// the signed 64-bit argument of _mm256_set1_epi64x. -1 yields the same bit
// pattern without that conversion and matches the 32-bit knot_opmask.
static opmask_t knot_opmask(opmask_t x)
{
    const auto allTrue = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(x, allTrue);
}
// Builds a load mask covering the first num_to_read lanes: an integer with
// the low num_to_read bits set is handed to convert_int_to_avx2_mask_64bit.
// num_to_read must be below 64; a larger shift count is undefined.
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
    auto low_bits = (0x1ull << num_to_read);
    low_bits -= 0x1ull;
    return convert_int_to_avx2_mask_64bit(low_bits);
}
// Packs four 64-bit values into a 256-bit integer vector by forwarding them,
// in the given order, to _mm256_set_epi64x.
// NOTE(review): two signature lines are shown back to back; presumably the
// `int` one is the line this change removes and the `int64_t` one is its
// replacement (so full 64-bit lane values can be passed) -- confirm against
// the actual header.
static ymmi_t seti(int v1, int v2, int v3, int v4)
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
{
return _mm256_set_epi64x(v1, v2, v3, v4);
}
Expand Down Expand Up @@ -209,6 +214,9 @@ struct avx2_vector<int64_t> {
{
return v;
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed doubles and _mm256_movemask_pd collects one bit per 64-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_pd(_mm256_castsi256_pd(k));
    return lane_bits == 0;
}
};
template <>
struct avx2_vector<uint64_t> {
Expand Down Expand Up @@ -239,12 +247,17 @@ struct avx2_vector<uint64_t> {
{
return _mm256_set1_epi64x(type_max());
}
// Logical NOT of a vector mask (avx2_vector<uint64_t>): XORs x with an
// all-ones vector so every bit of the mask is flipped.
//
// The all-ones constant is written as -1 rather than 0xFFFF'FFFF'FFFF'FFFF:
// the hex literal has an unsigned type and would be implicitly narrowed to
// the signed 64-bit argument of _mm256_set1_epi64x. -1 yields the same bit
// pattern without that conversion and matches the 32-bit knot_opmask.
static opmask_t knot_opmask(opmask_t x)
{
    const auto allTrue = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(x, allTrue);
}
// Builds a load mask covering the first num_to_read lanes: an integer with
// the low num_to_read bits set is handed to convert_int_to_avx2_mask_64bit.
// num_to_read must be below 64; a larger shift count is undefined.
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
    auto low_bits = (0x1ull << num_to_read);
    low_bits -= 0x1ull;
    return convert_int_to_avx2_mask_64bit(low_bits);
}
// Packs four 64-bit values into a 256-bit integer vector by forwarding them,
// in the given order, to _mm256_set_epi64x.
// NOTE(review): two signature lines are shown back to back; presumably the
// `int` one is the line this change removes and the `int64_t` one is its
// replacement (so full 64-bit lane values can be passed) -- confirm against
// the actual header.
static ymmi_t seti(int v1, int v2, int v3, int v4)
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
{
return _mm256_set_epi64x(v1, v2, v3, v4);
}
Expand Down Expand Up @@ -378,6 +391,9 @@ struct avx2_vector<uint64_t> {
{
return v;
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed doubles and _mm256_movemask_pd collects one bit per 64-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_pd(_mm256_castsi256_pd(k));
    return lane_bits == 0;
}
};

/*
Expand Down Expand Up @@ -421,6 +437,11 @@ struct avx2_vector<double> {
{
return _mm256_set1_pd(type_max());
}
// Logical NOT of a vector mask (avx2_vector<double>): XORs x with an
// all-ones vector so every bit of the mask is flipped.
//
// The all-ones constant is written as -1 rather than 0xFFFF'FFFF'FFFF'FFFF:
// the hex literal has an unsigned type and would be implicitly narrowed to
// the signed 64-bit argument of _mm256_set1_epi64x. -1 yields the same bit
// pattern without that conversion and matches the 32-bit knot_opmask.
static opmask_t knot_opmask(opmask_t x)
{
    const auto allTrue = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(x, allTrue);
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
auto mask = ((0x1ull << num_to_read) - 0x1ull);
Expand All @@ -440,7 +461,7 @@ struct avx2_vector<double> {
static_assert(type == (0x01 | 0x80), "should not reach here");
}
}
// Packs four 64-bit values into a 256-bit integer vector by forwarding them,
// in the given order, to _mm256_set_epi64x.
// NOTE(review): two signature lines are shown back to back; presumably the
// `int` one is the line this change removes and the `int64_t` one is its
// replacement (so full 64-bit lane values can be passed) -- confirm against
// the actual header.
static ymmi_t seti(int v1, int v2, int v3, int v4)
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
{
return _mm256_set_epi64x(v1, v2, v3, v4);
}
Expand Down Expand Up @@ -571,6 +592,9 @@ struct avx2_vector<double> {
{
return _mm256_castpd_si256(v);
}
// True when no lane of mask k is selected: the mask is reinterpreted as
// packed doubles and _mm256_movemask_pd collects one bit per 64-bit lane;
// the result must be zero.
static bool all_false(opmask_t k)
{
    const int lane_bits = _mm256_movemask_pd(_mm256_castsi256_pd(k));
    return lane_bits == 0;
}
};

struct avx2_64bit_swizzle_ops {
Expand Down
21 changes: 21 additions & 0 deletions src/avx512-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ struct zmm_vector<float16> {
exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
return _kxor_mask32(mask_ge, neg);
}
// Lane-wise equality for zmm_vector<float16>. The lanes are compared as
// unsigned 16-bit integers (_mm512_cmpeq_epu16_mask), i.e. by bit pattern.
// NOTE(review): a bit-pattern compare treats +0 and -0 as different and
// identical NaN encodings as equal -- confirm callers expect that.
static opmask_t eq(reg_t x, reg_t y)
{
    const opmask_t same_lanes = _mm512_cmpeq_epu16_mask(x, y);
    return same_lanes;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
return ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -186,6 +190,9 @@ struct zmm_vector<float16> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -238,6 +245,10 @@ struct zmm_vector<int16_t> {
{
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
}
// Lane-wise equality for zmm_vector<int16_t>: returns the mask produced by
// _mm512_cmpeq_epi16_mask for the 16-bit lanes of x and y.
static opmask_t eq(reg_t x, reg_t y)
{
    const opmask_t same_lanes = _mm512_cmpeq_epi16_mask(x, y);
    return same_lanes;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
return ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -323,6 +334,9 @@ struct zmm_vector<int16_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -374,6 +388,10 @@ struct zmm_vector<uint16_t> {
{
return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
}
// Lane-wise equality for zmm_vector<uint16_t>: returns the mask produced by
// _mm512_cmpeq_epu16_mask for the unsigned 16-bit lanes of x and y.
static opmask_t eq(reg_t x, reg_t y)
{
    const opmask_t same_lanes = _mm512_cmpeq_epu16_mask(x, y);
    return same_lanes;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
return ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -457,6 +475,9 @@ struct zmm_vector<uint16_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
9 changes: 9 additions & 0 deletions src/avx512-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ struct zmm_vector<int32_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -377,6 +380,9 @@ struct zmm_vector<uint32_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -570,6 +576,9 @@ struct zmm_vector<float> {
{
return _mm512_castps_si512(v);
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
9 changes: 9 additions & 0 deletions src/avx512-64bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,9 @@ struct zmm_vector<int64_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -903,6 +906,9 @@ struct zmm_vector<uint64_t> {
{
return v;
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down Expand Up @@ -1093,6 +1099,9 @@ struct zmm_vector<double> {
{
return _mm512_castpd_si512(v);
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
7 changes: 7 additions & 0 deletions src/avx512fp16-16bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ struct zmm_vector<_Float16> {
{
return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
}
// Lane-wise equality for zmm_vector<_Float16>: a half-precision
// floating-point compare using the _CMP_EQ_OQ predicate (ordered, quiet
// "equal"), so this is a numeric comparison rather than a bit-pattern one.
static opmask_t eq(reg_t x, reg_t y)
{
    const opmask_t same_lanes = _mm512_cmp_ph_mask(x, y, _CMP_EQ_OQ);
    return same_lanes;
}
static opmask_t get_partial_loadmask(uint64_t num_to_read)
{
return ((0x1ull << num_to_read) - 0x1ull);
Expand Down Expand Up @@ -150,6 +154,9 @@ struct zmm_vector<_Float16> {
{
return _mm512_castph_si512(v);
}
// True when mask k compares equal to zero, i.e. it has no bit set.
static bool all_false(opmask_t k)
{
    const bool any_set = (k != 0);
    return !any_set;
}
static int double_compressstore(type_t *left_addr,
type_t *right_addr,
opmask_t k,
Expand Down
3 changes: 3 additions & 0 deletions src/xss-common-includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,7 @@ struct avx2_half_vector;

// Tag naming a SIMD instruction-set family; backed by int.
enum class simd_type : int {
    AVX2,
    AVX512
};

// Forward declaration of the scalar comparison helper taking two values of
// type T and returning bool. T defaults to vtype::type_t. Because the
// default template argument is supplied on this declaration, the definition
// elsewhere must not repeat it.
template <typename vtype, typename T = typename vtype::type_t>
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);

#endif // XSS_COMMON_INCLUDES
13 changes: 10 additions & 3 deletions src/xss-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ X86_SIMD_SORT_INLINE bool array_has_nan(type_t *arr, arrsize_t size)
else {
in = vtype::loadu(arr + ii);
}
auto nanmask = vtype::convert_mask_to_int(vtype::template fpclass<0x01 | 0x80>(in));
auto nanmask = vtype::convert_mask_to_int(
vtype::template fpclass<0x01 | 0x80>(in));
if (nanmask != 0x00) {
found_nan = true;
break;
Expand Down Expand Up @@ -136,7 +137,7 @@ X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T *arr, arrsize_t size)
return size - count - 1;
}

// Generic scalar comparison: returns true when a is strictly less than b
// using T's operator<. vtype is not used in the body.
// NOTE(review): two template headers are shown back to back; presumably the
// first (with the default for T) is the line this change removes and the
// second replaces it, the default now living on the forward declaration in
// xss-common-includes.h -- confirm against the actual header.
template <typename vtype, typename T = typename vtype::type_t>
template <typename vtype, typename T>
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b)
{
return a < b;
}
Expand Down Expand Up @@ -499,14 +500,20 @@ qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters)
return;
}

type_t pivot = get_pivot_blocks<vtype, type_t>(arr, left, right);
auto pivot_result = get_pivot_smart<vtype, type_t>(arr, left, right);
type_t pivot = pivot_result.pivot;

if (pivot_result.result == pivot_result_t::Sorted) { return; }

type_t smallest = vtype::type_max();
type_t biggest = vtype::type_min();

arrsize_t pivot_index
= partition_avx512_unrolled<vtype, vtype::partition_unroll_factor>(
arr, left, right + 1, pivot, &smallest, &biggest);

if (pivot_result.result == pivot_result_t::Only2Values) { return; }

if (pivot != smallest)
qsort_<vtype>(arr, left, pivot_index - 1, max_iters - 1);
if (pivot != biggest) qsort_<vtype>(arr, pivot_index, right, max_iters - 1);
Expand Down
10 changes: 4 additions & 6 deletions src/xss-network-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,8 @@ bitonic_fullmerge_n_vec(typename keyType::reg_t *keys,
}

template <typename keyType, typename indexType, int numVecs>
X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
arrsize_t *indices,
int N)
X86_SIMD_SORT_INLINE void
argsort_n_vec(typename keyType::type_t *keys, arrsize_t *indices, int N)
{
using kreg_t = typename keyType::reg_t;
using ireg_t = typename indexType::reg_t;
Expand Down Expand Up @@ -586,9 +585,8 @@ X86_SIMD_SORT_INLINE void kvsort_n_vec(typename keyType::type_t *keys,
}

template <typename keyType, typename indexType, int maxN>
X86_SIMD_SORT_INLINE void argsort_n(typename keyType::type_t *keys,
arrsize_t *indices,
int N)
X86_SIMD_SORT_INLINE void
argsort_n(typename keyType::type_t *keys, arrsize_t *indices, int N)
{
static_assert(keyType::numlanes == indexType::numlanes,
"invalid pairing of value/index types");
Expand Down
Loading
Loading