#include "argsort.cuh"

+ #ifdef GGML_CUDA_USE_CUB
+ #    include <cub/cub.cuh>
+ using namespace cub;
+ #endif // GGML_CUDA_USE_CUB
+
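+ // Fill each row of the index buffer with 0..ncols-1 so the sorted values come out as column indices.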
+ static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+     const int col = blockIdx.x * blockDim.x + threadIdx.x;
+     const int row = blockIdx.y;
+
+     if (col < ncols && row < nrows) {
+         indices[row * ncols + col] = col;
+     }
+ }
+
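+ // Build the nrows + 1 segment offsets required by the segmented sort: offsets[i] = i * ncols.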
+ static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+     if (idx <= nrows) {
+         offsets[idx] = idx * ncols;
+     }
+ }
+
+ #ifdef GGML_CUDA_USE_CUB
+ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                                      const float *    x,
+                                      int *            dst,
+                                      const int        ncols,
+                                      const int        nrows,
+                                      ggml_sort_order  order,
+                                      cudaStream_t     stream) {
+     ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
+     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+     ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
+
+     int *   temp_indices = temp_indices_alloc.get();
+     float * temp_keys    = temp_keys_alloc.get();
+     int *   d_offsets    = offsets_alloc.get();
+
+     static const int block_size = 256;
+     const dim3       grid_size((ncols + block_size - 1) / block_size, nrows);
+     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+     // nrows + 1 offsets must be written, so size the grid for nrows + 1 threads.
+     const dim3 offset_grid((nrows + 1 + block_size - 1) / block_size);
+     init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+
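+     // Sort a scratch copy of the keys so the source tensor data is left untouched.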
+     cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
+
+     size_t temp_storage_bytes = 0;
+
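+     // First call with a null workspace: CUB only computes the required temp storage size.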
+     if (order == GGML_SORT_ORDER_ASC) {
+         DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                             temp_indices, dst,                                  // values (indices)
+                                             ncols * nrows, nrows,                               // num items, num segments
+                                             d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
+                                             stream);
+     } else {
+         DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                       dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
+                                                       sizeof(float) * 8, stream);
+     }
+
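+     // Allocate the workspace CUB asked for, then run the actual segmented sort.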
+     ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+     void * d_temp_storage = temp_storage_alloc.get();
+
+     if (order == GGML_SORT_ORDER_ASC) {
+         DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                             ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
+                                             stream);
+     } else {
+         DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                       temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                       0, sizeof(float) * 8, stream);
+     }
+ }
+ #endif // GGML_CUDA_USE_CUB
+
+ // Bitonic sort implementation
template <typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
    return n;
}

- static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ static void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                          int *           dst,
+                                          const int       ncols,
+                                          const int       nrows,
+                                          ggml_sort_order order,
+                                          cudaStream_t    stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);

@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
-         k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+         k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+             <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
-         k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+         k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+             <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ABORT("fatal error");
    }
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

-     argsort_f32_i32_cuda(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ #ifdef GGML_CUDA_USE_CUB
+     const int    ncols_pad      = next_power_of_2(ncols);
+     const size_t shared_mem     = ncols_pad * sizeof(int);
+     const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
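+     // The bitonic kernel uses one thread per padded column and ncols_pad ints of shared memory,
+     // so fall back to the CUB path when either limit would be exceeded.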
+     if (shared_mem > max_shared_mem || ncols > 1024) {
+         ggml_cuda_pool & pool = ctx.pool();
+         argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+     } else {
+         argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+     }
+ #else
+     argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ #endif
}