@@ -65,24 +65,15 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
6565              });
6666}
6767
68- /*  Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
69-  * undefined template 'zmm_vector<unsigned long>'*/  
70- #ifdef  __APPLE__
71- using  argtype = typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
72-                                           ymm_vector<uint32_t >,
73-                                           zmm_vector<uint64_t >>::type;
74- #else 
75- using  argtype = typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
76-                                           ymm_vector<arrsize_t >,
77-                                           zmm_vector<arrsize_t >>::type;
78- #endif 
79- using  argreg_t  = typename  argtype::reg_t ;
80- 
8168/* 
8269 * Partitions one ZMM register based on the pivot and returns the index of the
8370 * last element that is less than or equal to the pivot.
8471 */  
85- template  <typename  vtype, typename  type_t , typename  reg_t >
72+ template  <typename  vtype,
73+           typename  argtype,
74+           typename  type_t  = typename  vtype::type_t ,
75+           typename  reg_t  = typename  vtype::reg_t ,
76+           typename  argreg_t  = typename  argtype::reg_t >
8677X86_SIMD_SORT_INLINE int32_t  partition_vec (type_t  *arg,
8778                                           arrsize_t  left,
8879                                           arrsize_t  right,
@@ -107,7 +98,11 @@ X86_SIMD_SORT_INLINE int32_t partition_vec(type_t *arg,
10798 * Partitions an array based on the pivot and returns the index of the
10899 * last element that is less than or equal to the pivot.
109100 */  
110- template  <typename  vtype, typename  type_t >
101+ template  <typename  vtype,
102+           typename  argtype,
103+           typename  type_t  = typename  vtype::type_t ,
104+           typename  reg_t  = typename  vtype::reg_t ,
105+           typename  argreg_t  = typename  argtype::reg_t >
111106X86_SIMD_SORT_INLINE arrsize_t  partition_avx512 (type_t  *arr,
112107                                                arrsize_t  *arg,
113108                                                arrsize_t  left,
@@ -131,22 +126,22 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
131126    if  (left == right)
132127        return  left; /*  less than vtype::numlanes elements in the array */ 
133128
134-     using  reg_t  = typename  vtype::reg_t ;
135129    reg_t  pivot_vec = vtype::set1 (pivot);
136130    reg_t  min_vec = vtype::set1 (*smallest);
137131    reg_t  max_vec = vtype::set1 (*biggest);
138132
139133    if  (right - left == vtype::numlanes) {
140134        argreg_t  argvec = argtype::loadu (arg + left);
141135        reg_t  vec = vtype::i64gather (arr, arg + left);
142-         int32_t  amount_gt_pivot = partition_vec<vtype>(arg,
143-                                                        left,
144-                                                        left + vtype::numlanes,
145-                                                        argvec,
146-                                                        vec,
147-                                                        pivot_vec,
148-                                                        &min_vec,
149-                                                        &max_vec);
136+         int32_t  amount_gt_pivot
137+                 = partition_vec<vtype, argtype>(arg,
138+                                                 left,
139+                                                 left + vtype::numlanes,
140+                                                 argvec,
141+                                                 vec,
142+                                                 pivot_vec,
143+                                                 &min_vec,
144+                                                 &max_vec);
150145        *smallest = vtype::reducemin (min_vec);
151146        *biggest = vtype::reducemax (max_vec);
152147        return  left + (vtype::numlanes - amount_gt_pivot);
@@ -183,46 +178,49 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512(type_t *arr,
183178        }
184179        //  partition the current vector and save it on both sides of the array
185180        int32_t  amount_gt_pivot
186-                 = partition_vec<vtype>(arg,
187-                                        l_store,
188-                                        r_store + vtype::numlanes,
189-                                        arg_vec,
190-                                        curr_vec,
191-                                        pivot_vec,
192-                                        &min_vec,
193-                                        &max_vec);
181+                 = partition_vec<vtype, argtype >(arg,
182+                                                  l_store,
183+                                                  r_store + vtype::numlanes,
184+                                                  arg_vec,
185+                                                  curr_vec,
186+                                                  pivot_vec,
187+                                                  &min_vec,
188+                                                  &max_vec);
194189        ;
195190        r_store -= amount_gt_pivot;
196191        l_store += (vtype::numlanes - amount_gt_pivot);
197192    }
198193
199194    /*  partition and save vec_left and vec_right */ 
200-     int32_t  amount_gt_pivot = partition_vec<vtype>(arg,
201-                                                    l_store,
202-                                                    r_store + vtype::numlanes,
203-                                                    argvec_left,
204-                                                    vec_left,
205-                                                    pivot_vec,
206-                                                    &min_vec,
207-                                                    &max_vec);
195+     int32_t  amount_gt_pivot
196+             = partition_vec<vtype, argtype>(arg,
197+                                             l_store,
198+                                             r_store + vtype::numlanes,
199+                                             argvec_left,
200+                                             vec_left,
201+                                             pivot_vec,
202+                                             &min_vec,
203+                                             &max_vec);
208204    l_store += (vtype::numlanes - amount_gt_pivot);
209-     amount_gt_pivot = partition_vec<vtype>(arg,
210-                                            l_store,
211-                                            l_store + vtype::numlanes,
212-                                            argvec_right,
213-                                            vec_right,
214-                                            pivot_vec,
215-                                            &min_vec,
216-                                            &max_vec);
205+     amount_gt_pivot = partition_vec<vtype, argtype >(arg,
206+                                                      l_store,
207+                                                      l_store + vtype::numlanes,
208+                                                      argvec_right,
209+                                                      vec_right,
210+                                                      pivot_vec,
211+                                                      &min_vec,
212+                                                      &max_vec);
217213    l_store += (vtype::numlanes - amount_gt_pivot);
218214    *smallest = vtype::reducemin (min_vec);
219215    *biggest = vtype::reducemax (max_vec);
220216    return  l_store;
221217}
222218
223219template  <typename  vtype,
220+           typename  argtype,
224221          int  num_unroll,
225-           typename  type_t  = typename  vtype::type_t >
222+           typename  type_t  = typename  vtype::type_t ,
223+           typename  argreg_t  = typename  argtype::reg_t >
226224X86_SIMD_SORT_INLINE arrsize_t  partition_avx512_unrolled (type_t  *arr,
227225                                                         arrsize_t  *arg,
228226                                                         arrsize_t  left,
@@ -232,7 +230,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
232230                                                         type_t  *biggest)
233231{
234232    if  (right - left <= 8  * num_unroll * vtype::numlanes) {
235-         return  partition_avx512<vtype>(
233+         return  partition_avx512<vtype, argtype >(
236234                arr, arg, left, right, pivot, smallest, biggest);
237235    }
238236    /*  make array length divisible by vtype::numlanes , shortening the array */ 
@@ -305,14 +303,14 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
305303        X86_SIMD_SORT_UNROLL_LOOP (8 )
306304        for  (int  ii = 0 ; ii < num_unroll; ++ii) {
307305            int32_t  amount_gt_pivot
308-                     = partition_vec<vtype>(arg,
309-                                            l_store,
310-                                            r_store + vtype::numlanes,
311-                                            arg_vec[ii],
312-                                            curr_vec[ii],
313-                                            pivot_vec,
314-                                            &min_vec,
315-                                            &max_vec);
306+                     = partition_vec<vtype, argtype >(arg,
307+                                                      l_store,
308+                                                      r_store + vtype::numlanes,
309+                                                      arg_vec[ii],
310+                                                      curr_vec[ii],
311+                                                      pivot_vec,
312+                                                      &min_vec,
313+                                                      &max_vec);
316314            l_store += (vtype::numlanes - amount_gt_pivot);
317315            r_store -= amount_gt_pivot;
318316        }
@@ -322,28 +320,28 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr,
322320    X86_SIMD_SORT_UNROLL_LOOP (8 )
323321    for  (int  ii = 0 ; ii < num_unroll; ++ii) {
324322        int32_t  amount_gt_pivot
325-                 = partition_vec<vtype>(arg,
326-                                        l_store,
327-                                        r_store + vtype::numlanes,
328-                                        argvec_left[ii],
329-                                        vec_left[ii],
330-                                        pivot_vec,
331-                                        &min_vec,
332-                                        &max_vec);
323+                 = partition_vec<vtype, argtype >(arg,
324+                                                  l_store,
325+                                                  r_store + vtype::numlanes,
326+                                                  argvec_left[ii],
327+                                                  vec_left[ii],
328+                                                  pivot_vec,
329+                                                  &min_vec,
330+                                                  &max_vec);
333331        l_store += (vtype::numlanes - amount_gt_pivot);
334332        r_store -= amount_gt_pivot;
335333    }
336334    X86_SIMD_SORT_UNROLL_LOOP (8 )
337335    for  (int  ii = 0 ; ii < num_unroll; ++ii) {
338336        int32_t  amount_gt_pivot
339-                 = partition_vec<vtype>(arg,
340-                                        l_store,
341-                                        r_store + vtype::numlanes,
342-                                        argvec_right[ii],
343-                                        vec_right[ii],
344-                                        pivot_vec,
345-                                        &min_vec,
346-                                        &max_vec);
337+                 = partition_vec<vtype, argtype >(arg,
338+                                                  l_store,
339+                                                  r_store + vtype::numlanes,
340+                                                  argvec_right[ii],
341+                                                  vec_right[ii],
342+                                                  pivot_vec,
343+                                                  &min_vec,
344+                                                  &max_vec);
347345        l_store += (vtype::numlanes - amount_gt_pivot);
348346        r_store -= amount_gt_pivot;
349347    }
@@ -379,7 +377,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
379377    }
380378}
381379
382- template  <typename  vtype, typename  indexType , typename  type_t >
380+ template  <typename  vtype, typename  argtype , typename  type_t >
383381X86_SIMD_SORT_INLINE void  argsort_64bit_ (type_t  *arr,
384382                                         arrsize_t  *arg,
385383                                         arrsize_t  left,
@@ -397,24 +395,24 @@ X86_SIMD_SORT_INLINE void argsort_64bit_(type_t *arr,
397395     * Base case: use bitonic networks to sort arrays <= 256
398396     */  
399397    if  (right + 1  - left <= 256 ) {
400-         argsort_n<vtype, indexType , 256 >(
398+         argsort_n<vtype, argtype , 256 >(
401399                arr, arg + left, (int32_t )(right + 1  - left));
402400        return ;
403401    }
404402    type_t  pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
405403    type_t  smallest = vtype::type_max ();
406404    type_t  biggest = vtype::type_min ();
407-     arrsize_t  pivot_index = partition_avx512_unrolled<vtype, 4 >(
405+     arrsize_t  pivot_index = partition_avx512_unrolled<vtype, argtype,  4 >(
408406            arr, arg, left, right + 1 , pivot, &smallest, &biggest);
409407    if  (pivot != smallest)
410-         argsort_64bit_<vtype, indexType >(
408+         argsort_64bit_<vtype, argtype >(
411409                arr, arg, left, pivot_index - 1 , max_iters - 1 );
412410    if  (pivot != biggest)
413-         argsort_64bit_<vtype, indexType >(
411+         argsort_64bit_<vtype, argtype >(
414412                arr, arg, pivot_index, right, max_iters - 1 );
415413}
416414
417- template  <typename  vtype, typename  indexType , typename  type_t >
415+ template  <typename  vtype, typename  argtype , typename  type_t >
418416X86_SIMD_SORT_INLINE void  argselect_64bit_ (type_t  *arr,
419417                                           arrsize_t  *arg,
420418                                           arrsize_t  pos,
@@ -433,20 +431,20 @@ X86_SIMD_SORT_INLINE void argselect_64bit_(type_t *arr,
433431     * Base case: use bitonic networks to sort arrays <= 256
434432     */  
435433    if  (right + 1  - left <= 256 ) {
436-         argsort_n<vtype, indexType , 256 >(
434+         argsort_n<vtype, argtype , 256 >(
437435                arr, arg + left, (int32_t )(right + 1  - left));
438436        return ;
439437    }
440438    type_t  pivot = get_pivot_64bit<vtype>(arr, arg, left, right);
441439    type_t  smallest = vtype::type_max ();
442440    type_t  biggest = vtype::type_min ();
443-     arrsize_t  pivot_index = partition_avx512_unrolled<vtype, 4 >(
441+     arrsize_t  pivot_index = partition_avx512_unrolled<vtype, argtype,  4 >(
444442            arr, arg, left, right + 1 , pivot, &smallest, &biggest);
445443    if  ((pivot != smallest) && (pos < pivot_index))
446-         argselect_64bit_<vtype, indexType >(
444+         argselect_64bit_<vtype, argtype >(
447445                arr, arg, pos, left, pivot_index - 1 , max_iters - 1 );
448446    else  if  ((pivot != biggest) && (pos >= pivot_index))
449-         argselect_64bit_<vtype, indexType >(
447+         argselect_64bit_<vtype, argtype >(
450448                arr, arg, pos, pivot_index, right, max_iters - 1 );
451449}
452450
@@ -455,14 +453,24 @@ template <typename T>
455453X86_SIMD_SORT_INLINE void 
456454avx512_argsort (T *arr, arrsize_t  *arg, arrsize_t  arrsize, bool  hasnan = false )
457455{
456+     /*  TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ 
458457    using  vectype = typename  std::conditional<sizeof (T) == sizeof (int32_t ),
459458                                              ymm_vector<T>,
460459                                              zmm_vector<T>>::type;
461-     using  indextype =
462-             typename  std::conditional<sizeof (arrsize_t ) * vectype::numlanes
463-                                               == 32 ,
460+ 
461+ /*  Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
462+  * undefined template 'zmm_vector<unsigned long>'*/  
463+ #ifdef  __APPLE__
464+     using  argtype =
465+             typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
466+                                       ymm_vector<uint32_t >,
467+                                       zmm_vector<uint64_t >>::type;
468+ #else 
469+     using  argtype =
470+             typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
464471                                      ymm_vector<arrsize_t >,
465472                                      zmm_vector<arrsize_t >>::type;
473+ #endif 
466474
467475    if  (arrsize > 1 ) {
468476        if  constexpr  (std::is_floating_point_v<T>) {
@@ -472,7 +480,7 @@ avx512_argsort(T *arr, arrsize_t *arg, arrsize_t arrsize, bool hasnan = false)
472480            }
473481        }
474482        UNUSED (hasnan);
475-         argsort_64bit_<vectype, indextype >(
483+         argsort_64bit_<vectype, argtype >(
476484                arr, arg, 0 , arrsize - 1 , 2  * (arrsize_t )log2 (arrsize));
477485    }
478486}
@@ -495,14 +503,24 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
495503                                           arrsize_t  arrsize,
496504                                           bool  hasnan = false )
497505{
506+     /*  TODO optimization: on 32-bit, use zmm_vector for 32-bit dtype */ 
498507    using  vectype = typename  std::conditional<sizeof (T) == sizeof (int32_t ),
499508                                              ymm_vector<T>,
500509                                              zmm_vector<T>>::type;
501-     using  indextype =
502-             typename  std::conditional<sizeof (arrsize_t ) * vectype::numlanes
503-                                               == 32 ,
510+ 
511+ /*  Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
512+  * undefined template 'zmm_vector<unsigned long>'*/  
513+ #ifdef  __APPLE__
514+     using  argtype =
515+             typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
516+                                       ymm_vector<uint32_t >,
517+                                       zmm_vector<uint64_t >>::type;
518+ #else 
519+     using  argtype =
520+             typename  std::conditional<sizeof (arrsize_t ) == sizeof (int32_t ),
504521                                      ymm_vector<arrsize_t >,
505522                                      zmm_vector<arrsize_t >>::type;
523+ #endif 
506524
507525    if  (arrsize > 1 ) {
508526        if  constexpr  (std::is_floating_point_v<T>) {
@@ -512,7 +530,7 @@ X86_SIMD_SORT_INLINE void avx512_argselect(T *arr,
512530            }
513531        }
514532        UNUSED (hasnan);
515-         argselect_64bit_<vectype, indextype >(
533+         argselect_64bit_<vectype, argtype >(
516534                arr, arg, k, 0 , arrsize - 1 , 2  * (arrsize_t )log2 (arrsize));
517535    }
518536}
0 commit comments