@@ -129,7 +129,8 @@ struct ring_buffer {
129129};
130130
131131// writes result in res, does not mutate cur
132- static void llama_token_data_array_sort (const llama_token_data_array & cur, int k, std::vector<llama_token_data> & data) {
132+ // reduces the size of cur_p to npartial, keeping only the top npartial elements
133+ static void llama_token_data_array_partial_sort (const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
133134 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
134135 return a.logit > b.logit ;
135136 };
@@ -158,12 +159,12 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
158159 int ib = nbuckets - 1 ;
159160 for ( ; ib >= 0 ; --ib) {
160161 nhave += histo[ib];
161- if (nhave >= k ) {
162+ if (nhave >= npartial ) {
162163 break ;
163164 }
164165 }
165- data .resize (nhave);
166- auto * ptr = data .data ();
166+ res .resize (nhave);
167+ auto * ptr = res .data ();
167168 bucket_ptrs.reserve (nbuckets - ib);
168169 for (int j = nbuckets - 1 ; j >= ib; --j) {
169170 bucket_ptrs.push_back (ptr);
@@ -176,32 +177,39 @@ static void llama_token_data_array_sort(const llama_token_data_array & cur, int
176177 }
177178 }
178179
179- ptr = data .data ();
180+ ptr = res .data ();
180181 int ndone = 0 ;
181182 for (int j = nbuckets - 1 ; j > ib; --j) {
182183 std::sort (ptr, ptr + histo[j], comp);
183184 ptr += histo[j];
184185 ndone += histo[j];
185186 }
186- std::partial_sort (ptr, ptr + k - ndone, ptr + histo[ib], comp);
187+ std::partial_sort (ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
187188}
188189
189- // buf is a helper buffer that can optionally be utilized
190- static void llama_token_data_array_sort_inplace (llama_token_data_array * cur_p, int k ) {
190+ // reduces the size of cur_p to npartial, keeping only the top npartial elements
191+ static void llama_token_data_array_partial_sort_inplace (llama_token_data_array * cur_p, int npartial ) {
191192 static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
192193 return a.logit > b.logit ;
193194 };
194195
195- if (k <= 128 ) {
196- std::partial_sort (cur_p->data , cur_p->data + k, cur_p->data + cur_p->size , comp);
196+ if (npartial <= 128 ) {
197+ std::partial_sort (cur_p->data , cur_p->data + npartial, cur_p->data + cur_p->size , comp);
198+
199+ cur_p->size = npartial;
200+ cur_p->sorted = true ;
201+
197202 return ;
198203 }
199204
200205 std::vector<llama_token_data> tmp;
201206
202- llama_token_data_array_sort (*cur_p, k, tmp);
207+ llama_token_data_array_partial_sort (*cur_p, npartial, tmp);
208+
209+ std::copy (tmp.data (), tmp.data () + npartial, cur_p->data );
203210
204- std::copy (tmp.data (), tmp.data () + k, cur_p->data );
211+ cur_p->size = npartial;
212+ cur_p->sorted = true ;
205213}
206214
207215static int llama_sample_dist (llama_token_data_array * cur_p, std::mt19937 & rng) {
@@ -281,8 +289,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_s
281289
282290 // Sort the logits in descending order if requested
283291 if (do_sort && !cur_p->sorted ) {
284- llama_token_data_array_sort_inplace (cur_p, cur_p->size );
285- cur_p->sorted = true ;
292+ llama_token_data_array_partial_sort_inplace (cur_p, cur_p->size );
286293 }
287294
288295 float max_l = cur_p->data [0 ].logit ;
@@ -318,8 +325,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
318325
319326 // Sort scores in descending order
320327 if (!cur_p->sorted ) {
321- llama_token_data_array_sort_inplace (cur_p, k);
322- cur_p->sorted = true ;
328+ llama_token_data_array_partial_sort_inplace (cur_p, k);
323329 }
324330
325331 cur_p->size = k;
@@ -722,12 +728,11 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
722728 // if not sorted, try adaptive top-k sorting
723729 if (!cur_p->sorted && cur_p->size > 1024 ) {
724730 k = std::min<size_t >(256 , cur_p->size );
725- llama_token_data_array_sort (*cur_p, k, buf_sort);
731+ llama_token_data_array_partial_sort (*cur_p, k, buf_sort);
726732 pdata = buf_sort.data ();
727733 } else if (!cur_p->sorted ) {
728734 // small candidates -> sort inplace
729- llama_token_data_array_sort_inplace (cur_p, k);
730- cur_p->sorted = true ;
735+ llama_token_data_array_partial_sort_inplace (cur_p, k);
731736 }
732737
733738 // Compute the cumulative probabilities
@@ -747,7 +752,7 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
747752 // we exceeded the current top-k heuristic -> increase k and continue
748753 if (!cur_p->sorted && i == k - 1 ) {
749754 k = cur_p->size ;
750- llama_token_data_array_sort (*cur_p, k, buf_sort);
755+ llama_token_data_array_partial_sort (*cur_p, k, buf_sort);
751756 pdata = buf_sort.data ();
752757 }
753758 }
@@ -838,8 +843,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
838843 if (!min_p_applied) {
839844 // Sort the logits in descending order
840845 if (!cur_p->sorted ) {
841- llama_token_data_array_sort_inplace (cur_p, cur_p->size );
842- cur_p->sorted = true ;
846+ llama_token_data_array_partial_sort_inplace (cur_p, cur_p->size );
843847 }
844848
845849 const float min_logit = cur_p->data [0 ].logit + logf (ctx->p ); // min logit for p_i >= p * p_max
0 commit comments