Skip to content

Commit b0daf91

Browse files
author
Raghuveer Devulapalli
committed
get rid of global argtype definition
1 parent e7e452f commit b0daf91

File tree

2 files changed

+111
-93
lines changed

2 files changed

+111
-93
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 110 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,15 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
6565
});
6666
}
6767

68-
/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
69-
* undefined template 'zmm_vector<unsigned long>'*/
70-
#ifdef __APPLE__
71-
using argtype = typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
72-
ymm_vector<uint32_t>,
73-
zmm_vector<uint64_t>>::type;
74-
#else
75-
using argtype = typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
76-
ymm_vector<arrsize_t>,
77-
zmm_vector<arrsize_t>>::type;
78-
#endif
79-
using argreg_t = typename argtype::reg_t;
80-
8168
/*
8269
* Parition one ZMM register based on the pivot and returns the index of the
8370
* last element that is less than equal to the pivot.
8471
*/
85-
template <typename vtype, typename type_t, typename reg_t>
72+
template <typename vtype,
73+
typename argtype,
74+
typename type_t = typename vtype::type_t,
75+
typename reg_t = typename vtype::reg_t,
76+
typename argreg_t = typename argtype::reg_t>
8677
X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
8778
arrsize_t left,
8879
arrsize_t right,
@@ -107,7 +98,11 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
10798
* Parition an array based on the pivot and returns the index of the
10899
* last element that is less than equal to the pivot.
109100
*/
110-
template <typename vtype, typename type_t>
101+
template <typename vtype,
102+
typename argtype,
103+
typename type_t = typename vtype::type_t,
104+
typename reg_t = typename vtype::reg_t,
105+
typename argreg_t = typename argtype::reg_t>
111106
X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
112107
arrsize_t *arg,
113108
arrsize_t left,
@@ -131,22 +126,22 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
131126
if (left == right)
132127
return left; /* less than vtype::numlanes elements in the array */
133128

134-
using reg_t = typename vtype::reg_t;
135129
reg_t pivot_vec = vtype::set1(pivot);
136130
reg_t min_vec = vtype::set1(*smallest);
137131
reg_t max_vec = vtype::set1(*biggest);
138132

139133
if (right - left == vtype::numlanes) {
140134
argreg_t argvec = argtype::loadu(arg + left);
141135
reg_t vec = vtype::i64gather(arr, arg + left);
142-
int32_t amount_gt_pivot = partition_vec<vtype>(arg,
143-
left,
144-
left + vtype::numlanes,
145-
argvec,
146-
vec,
147-
pivot_vec,
148-
&min_vec,
149-
&max_vec);
136+
int32_t amount_gt_pivot
137+
= partition_vec<vtype, argtype>(arg,
138+
left,
139+
left + vtype::numlanes,
140+
argvec,
141+
vec,
142+
pivot_vec,
143+
&min_vec,
144+
&max_vec);
150145
*smallest = vtype::reducemin(min_vec);
151146
*biggest = vtype::reducemax(max_vec);
152147
return left + (vtype::numlanes - amount_gt_pivot);
@@ -183,46 +178,49 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
183178
}
184179
// partition the current vector and save it on both sides of the array
185180
int32_t amount_gt_pivot
186-
= partition_vec<vtype>(arg,
187-
l_store,
188-
r_store + vtype::numlanes,
189-
arg_vec,
190-
curr_vec,
191-
pivot_vec,
192-
&min_vec,
193-
&max_vec);
181+
= partition_vec<vtype, argtype>(arg,
182+
l_store,
183+
r_store + vtype::numlanes,
184+
arg_vec,
185+
curr_vec,
186+
pivot_vec,
187+
&min_vec,
188+
&max_vec);
194189
;
195190
r_store -= amount_gt_pivot;
196191
l_store += (vtype::numlanes - amount_gt_pivot);
197192
}
198193

199194
/* partition and save vec_left and vec_right */
200-
int32_t amount_gt_pivot = partition_vec<vtype>(arg,
201-
l_store,
202-
r_store + vtype::numlanes,
203-
argvec_left,
204-
vec_left,
205-
pivot_vec,
206-
&min_vec,
207-
&max_vec);
195+
int32_t amount_gt_pivot
196+
= partition_vec<vtype, argtype>(arg,
197+
l_store,
198+
r_store + vtype::numlanes,
199+
argvec_left,
200+
vec_left,
201+
pivot_vec,
202+
&min_vec,
203+
&max_vec);
208204
l_store += (vtype::numlanes - amount_gt_pivot);
209-
amount_gt_pivot = partition_vec<vtype>(arg,
210-
l_store,
211-
l_store + vtype::numlanes,
212-
argvec_right,
213-
vec_right,
214-
pivot_vec,
215-
&min_vec,
216-
&max_vec);
205+
amount_gt_pivot = partition_vec<vtype, argtype>(arg,
206+
l_store,
207+
l_store + vtype::numlanes,
208+
argvec_right,
209+
vec_right,
210+
pivot_vec,
211+
&min_vec,
212+
&max_vec);
217213
l_store += (vtype::numlanes - amount_gt_pivot);
218214
*smallest = vtype::reducemin(min_vec);
219215
*biggest = vtype::reducemax(max_vec);
220216
return l_store;
221217
}
222218

223219
template <typename vtype,
220+
typename argtype,
224221
int num_unroll,
225-
typename type_t = typename vtype::type_t>
222+
typename type_t = typename vtype::type_t,
223+
typename argreg_t = typename argtype::reg_t>
226224
X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
227225
arrsize_t *arg,
228226
arrsize_t left,
@@ -232,7 +230,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
232230
type_t *biggest)
233231
{
234232
if (right - left <= 8 * num_unroll * vtype::numlanes) {
235-
return partition_avx512<vtype>(
233+
return partition_avx512<vtype, argtype>(
236234
arr, arg, left, right, pivot, smallest, biggest);
237235
}
238236
/* make array length divisible by vtype::numlanes , shortening the array */
@@ -305,14 +303,14 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
305303
X86_SIMD_SORT_UNROLL_LOOP(8)
306304
for (int ii = 0; ii < num_unroll; ++ii) {
307305
int32_t amount_gt_pivot
308-
= partition_vec<vtype>(arg,
309-
l_store,
310-
r_store + vtype::numlanes,
311-
arg_vec[ii],
312-
curr_vec[ii],
313-
pivot_vec,
314-
&min_vec,
315-
&max_vec);
306+
= partition_vec<vtype, argtype>(arg,
307+
l_store,
308+
r_store + vtype::numlanes,
309+
arg_vec[ii],
310+
curr_vec[ii],
311+
pivot_vec,
312+
&min_vec,
313+
&max_vec);
316314
l_store += (vtype::numlanes - amount_gt_pivot);
317315
r_store -= amount_gt_pivot;
318316
}
@@ -322,28 +320,28 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
322320
X86_SIMD_SORT_UNROLL_LOOP(8)
323321
for (int ii = 0; ii < num_unroll; ++ii) {
324322
int32_t amount_gt_pivot
325-
= partition_vec<vtype>(arg,
326-
l_store,
327-
r_store + vtype::numlanes,
328-
argvec_left[ii],
329-
vec_left[ii],
330-
pivot_vec,
331-
&min_vec,
332-
&max_vec);
323+
= partition_vec<vtype, argtype>(arg,
324+
l_store,
325+
r_store + vtype::numlanes,
326+
argvec_left[ii],
327+
vec_left[ii],
328+
pivot_vec,
329+
&min_vec,
330+
&max_vec);
333331
l_store += (vtype::numlanes - amount_gt_pivot);
334332
r_store -= amount_gt_pivot;
335333
}
336334
X86_SIMD_SORT_UNROLL_LOOP(8)
337335
for (int ii = 0; ii < num_unroll; ++ii) {
338336
int32_t amount_gt_pivot
339-
= partition_vec<vtype>(arg,
340-
l_store,
341-
r_store + vtype::numlanes,
342-
argvec_right[ii],
343-
vec_right[ii],
344-
pivot_vec,
345-
&min_vec,
346-
&max_vec);
337+
= partition_vec<vtype, argtype>(arg,
338+
l_store,
339+
r_store + vtype::numlanes,
340+
argvec_right[ii],
341+
vec_right[ii],
342+
pivot_vec,
343+
&min_vec,
344+
&max_vec);
347345
l_store += (vtype::numlanes - amount_gt_pivot);
348346
r_store -= amount_gt_pivot;
349347
}
@@ -379,7 +377,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
379377
}
380378
}
381379

382-
template <typename vtype, typename indexType, typename type_t>
380+
template <typename vtype, typename argtype, typename type_t>
383381
X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
384382
arrsize_t *arg,
385383
arrsize_t left,
@@ -397,24 +395,24 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
397395
* Base case: use bitonic networks to sort arrays <= 64
398396
*/
399397
if (right + 1 - left <= 256) {
400-
argsort_n<vtype, indexType, 256>(
398+
argsort_n<vtype, argtype, 256>(
401399
arr, arg + left, (int32_t)(right + 1 - left));
402400
return;
403401
}
404402
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
405403
type_t smallest = vtype::type_max();
406404
type_t biggest = vtype::type_min();
407-
arrsize_t pivot_index = partition_avx512_unrolled<vtype, 4>(
405+
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
408406
arr, arg, left, right + 1, pivot, &smallest, &biggest);
409407
if (pivot != smallest)
410-
argsort_64bit_<vtype, indexType>(
408+
argsort_64bit_<vtype, argtype>(
411409
arr, arg, left, pivot_index - 1, max_iters - 1);
412410
if (pivot != biggest)
413-
argsort_64bit_<vtype, indexType>(
411+
argsort_64bit_<vtype, argtype>(
414412
arr, arg, pivot_index, right, max_iters - 1);
415413
}
416414

417-
template <typename vtype, typename indexType, typename type_t>
415+
template <typename vtype, typename argtype, typename type_t>
418416
X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
419417
arrsize_t *arg,
420418
arrsize_t pos,
@@ -433,20 +431,20 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
433431
* Base case: use bitonic networks to sort arrays <= 64
434432
*/
435433
if (right + 1 - left <= 256) {
436-
argsort_n<vtype, indexType, 256>(
434+
argsort_n<vtype, argtype, 256>(
437435
arr, arg + left, (int32_t)(right + 1 - left));
438436
return;
439437
}
440438
type_t pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
441439
type_t smallest = vtype::type_max();
442440
type_t biggest = vtype::type_min();
443-
arrsize_t pivot_index = partition_avx512_unrolled<vtype, 4>(
441+
arrsize_t pivot_index = partition_avx512_unrolled<vtype, argtype, 4>(
444442
arr, arg, left, right + 1, pivot, &smallest, &biggest);
445443
if ((pivot != smallest) && (pos < pivot_index))
446-
argselect_64bit_<vtype, indexType>(
444+
argselect_64bit_<vtype, argtype>(
447445
arr, arg, pos, left, pivot_index - 1, max_iters - 1);
448446
else if ((pivot != biggest) && (pos >= pivot_index))
449-
argselect_64bit_<vtype, indexType>(
447+
argselect_64bit_<vtype, argtype>(
450448
arr, arg, pos, pivot_index, right, max_iters - 1);
451449
}
452450

@@ -455,14 +453,24 @@ template <typename T>
455453
X86_SIMD_SORT_INLINE void
456454
avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
457455
{
456+
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
458457
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
459458
ymm_vector<T>,
460459
zmm_vector<T>>::type;
461-
using indextype =
462-
typename std::conditional<sizeof(arrsize_t) * vectype::numlanes
463-
== 32,
460+
461+
/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
462+
* undefined template 'zmm_vector<unsigned long>'*/
463+
#ifdef __APPLE__
464+
using argtype =
465+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
466+
ymm_vector<uint32_t>,
467+
zmm_vector<uint64_t>>::type;
468+
#else
469+
using argtype =
470+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
464471
ymm_vector<arrsize_t>,
465472
zmm_vector<arrsize_t>>::type;
473+
#endif
466474

467475
if (arrsize > 1) {
468476
if constexpr (std::is_floating_point_v<T>) {
@@ -472,7 +480,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
472480
}
473481
}
474482
UNUSED(hasnan);
475-
argsort_64bit_<vectype, indextype>(
483+
argsort_64bit_<vectype, argtype>(
476484
arr, arg, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
477485
}
478486
}
@@ -495,14 +503,24 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
495503
arrsize_t arrsize,
496504
bool hasnan = false)
497505
{
506+
/* TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */
498507
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
499508
ymm_vector<T>,
500509
zmm_vector<T>>::type;
501-
using indextype =
502-
typename std::conditional<sizeof(arrsize_t) * vectype::numlanes
503-
== 32,
510+
511+
/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
512+
* undefined template 'zmm_vector<unsigned long>'*/
513+
#ifdef __APPLE__
514+
using argtype =
515+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
516+
ymm_vector<uint32_t>,
517+
zmm_vector<uint64_t>>::type;
518+
#else
519+
using argtype =
520+
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
504521
ymm_vector<arrsize_t>,
505522
zmm_vector<arrsize_t>>::type;
523+
#endif
506524

507525
if (arrsize > 1) {
508526
if constexpr (std::is_floating_point_v<T>) {
@@ -512,7 +530,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
512530
}
513531
}
514532
UNUSED(hasnan);
515-
argselect_64bit_<vectype, indextype>(
533+
argselect_64bit_<vectype, argtype>(
516534
arr, arg, k, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize));
517535
}
518536
}

src/xss-network-keyvaluesort.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,4 +408,4 @@ X86_SIMD_SORT_INLINE void kvsort_n(typename keyType::type_t *keys,
408408
kvsort_n_vec<keyType, valueType, numVecs>(keys, values, N);
409409
}
410410

411-
#endif
411+
#endif

0 commit comments

Comments
 (0)