@@ -128,6 +128,77 @@ struct ring_buffer {
     std::vector<T> data;
 };
 
+static void llama_token_data_array_sort(const llama_token_data_array * cur_p, int k, std::vector<llama_token_data> & res) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    constexpr int   nbuckets     = 128;
+    constexpr float bucket_low   = -10.0f;
+    constexpr float bucket_high  =  10.0f;
+    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+    constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+    std::vector<int> bucket_idx(cur_p->size);
+    std::vector<int> histo(nbuckets, 0);
+
+    for (int i = 0; i < (int)cur_p->size; ++i) {
+        const float val = cur_p->data[i].logit;
+        int ib = int(bucket_scale * val + bucket_inter); // nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+        ib = std::max(0, std::min(nbuckets - 1, ib));
+        bucket_idx[i] = ib;
+        ++histo[ib];
+    }
+    int nhave = 0;
+    int ib = nbuckets - 1;
+    for ( ; ib >= 0; --ib) {
+        nhave += histo[ib];
+        if (nhave >= k) {
+            break;
+        }
+    }
+    res.resize(nhave);
+    auto * ptr = res.data();
+    std::vector<llama_token_data*> bucket_ptrs;
+    bucket_ptrs.reserve(nbuckets - ib);
+    for (int j = nbuckets - 1; j >= ib; --j) {
+        bucket_ptrs.push_back(ptr);
+        ptr += histo[j];
+    }
+    for (int i = 0; i < (int)cur_p->size; ++i) {
+        int j = bucket_idx[i];
+        if (j >= ib) {
+            *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
+        }
+    }
+
+    ptr = res.data();
+    int ndone = 0;
+    for (int j = nbuckets - 1; j > ib; --j) {
+        std::sort(ptr, ptr + histo[j], comp);
+        ptr += histo[j];
+        ndone += histo[j];
+    }
+    std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+}
+
+static void llama_token_data_array_sort(llama_token_data_array * cur_p, int k) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    if (k <= 128) {
+        std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
+        return;
+    }
+
+    std::vector<llama_token_data> tmp_tokens;
+
+    llama_token_data_array_sort(cur_p, k, tmp_tokens);
+
+    std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+}
+
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
 #ifdef __GNUC__
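The first overload above is a bucket sort: logits are binned into 128 equal-width buckets spanning [-10, 10], the histogram is scanned from the top bucket down until at least k candidates are covered, and only those buckets are then sorted (the lowest one only partially, via std::partial_sort). Below is a standalone sketch of just the bucket mapping; the constants mirror the patch, while the sample logits and the program itself are made up for illustration.

    // bucket mapping sketch (illustrative, not part of the patch)
    #include <algorithm>
    #include <cstdio>

    int main() {
        constexpr int   nbuckets     = 128;
        constexpr float bucket_low   = -10.0f;
        constexpr float bucket_high  =  10.0f;
        constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); // 6.4 buckets per unit of logit
        constexpr float bucket_inter = -bucket_low * bucket_scale;          // 64.0, shifts bucket_low to index 0

        const float vals[] = { 3.2f, -0.5f, -12.0f, 11.0f };
        for (float val : vals) {
            int ib = int(bucket_scale * val + bucket_inter);
            ib = std::max(0, std::min(nbuckets - 1, ib)); // out-of-range logits land in the edge buckets
            std::printf("logit %6.2f -> bucket %3d\n", val, ib); // prints buckets 84, 60, 0, 127
        }
        return 0;
    }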
@@ -200,18 +271,22 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp)
     }
 }
 
-static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort = true) {
     GGML_ASSERT(cur_p->size > 0);
 
-    // Sort the logits in descending order
-    if (!cur_p->sorted) {
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
+    // Sort the logits in descending order if requested
+    if (do_sort && !cur_p->sorted) {
+        llama_token_data_array_sort(cur_p, cur_p->size);
         cur_p->sorted = true;
     }
 
     float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
     float cum_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
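With do_sort == false the array may stay unsorted, so cur_p->data[0] is no longer guaranteed to hold the largest logit; that is what the added max_l scan above accounts for. The maximum is only needed for numerical stability of the softmax and cancels out after normalization. A minimal standalone sketch of that max-subtraction trick on plain floats (the function name and types here are illustrative, not from the patch):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // assumes vals is non-empty, mirroring the GGML_ASSERT above
    static void softmax_stable(std::vector<float> & vals) {
        // subtracting the maximum keeps std::exp in a safe range; the shift
        // cancels in the normalization, so the probabilities are unchanged
        const float max_l = *std::max_element(vals.begin(), vals.end());

        float cum_sum = 0.0f;
        for (float & v : vals) {
            v = std::exp(v - max_l);
            cum_sum += v;
        }
        for (float & v : vals) {
            v /= cum_sum;
        }
    }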
@@ -226,7 +301,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
 }
 
 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
@@ -239,63 +313,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
 
     // Sort scores in descending order
     if (!cur_p->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucket_inter = -bucket_low * bucket_scale;
-
-            std::vector<int> bucket_idx(cur_p->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                const float val = cur_p->data[i].logit;
-                int ib = int(bucket_scale * val + bucket_inter); // nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets - 1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) {
-                    break;
-                }
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto * ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets - 1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
+        llama_token_data_array_sort(cur_p, k);
         cur_p->sorted = true;
     }
 
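With the helper extracted, top-k only has to call it and then update the bookkeeping itself: the helper arranges the top-k entries at the front but does not touch the sorted flag or the size. A hedged sketch of that caller contract follows; it assumes the llama.h layouts of llama_token_data (id, logit, p) and llama_token_data_array (data, size, selected, sorted), would have to live inside llama-sampling.cpp because the helpers are static, and uses invented token ids, logits and k.

    static void example_top_k_usage() {
        std::vector<llama_token_data> cand = {
            { /* id */ 0, /* logit */ 1.5f, /* p */ 0.0f },
            { /* id */ 1, /* logit */ 3.2f, /* p */ 0.0f },
            { /* id */ 2, /* logit */ 0.7f, /* p */ 0.0f },
            { /* id */ 3, /* logit */ 2.9f, /* p */ 0.0f },
        };

        llama_token_data_array cur_p = { cand.data(), cand.size(), /* selected */ -1, /* sorted */ false };

        llama_token_data_array_sort(&cur_p, /* k */ 2); // ids 1 and 3 now lead the array
        cur_p.sorted = true; // the helper leaves the flag to the caller
        cur_p.size   = 2;    // entries past k are not guaranteed to be meaningful
    }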
@@ -576,7 +594,8 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p);
+    // sorting is not necessary here, but for now we are doing it
+    llama_sampler_softmax_impl(cur_p, true);
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
@@ -626,32 +645,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     );
 }
 
-// softmax
-
-static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
-    return "softmax";
-}
-
-static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
-    llama_sampler_softmax_impl(cur_p);
-}
-
-static struct llama_sampler_i llama_sampler_softmax_i = {
-    /* .name   = */ llama_sampler_softmax_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_softmax_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
-};
-
-struct llama_sampler * llama_sampler_init_softmax() {
-    return llama_sampler_init(
-        /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr
-    );
-}
-
 // top-k
 
 struct llama_sampler_top_k {
@@ -699,37 +692,67 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
 struct llama_sampler_top_p {
     const float p;
     const size_t min_keep;
+
+    std::vector<llama_token_data> buf_sort;
 };
 
 static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
     return "top-p";
 }
 
 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+    auto * ctx = (llama_sampler_top_p *) smpl->ctx;
 
     if (ctx->p >= 1.0f) {
         return;
     }
 
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, false);
+
+    size_t k = cur_p->size;
+    auto * pdata = cur_p->data;
+
+    auto & buf_sort = ctx->buf_sort;
+
+    // if not sorted, try adaptive top-k sorting
+    if (!cur_p->sorted && cur_p->size > 1024) {
+        k = std::min<size_t>(256, cur_p->size);
+        llama_token_data_array_sort(cur_p, k, buf_sort);
+        pdata = buf_sort.data();
+    } else if (!cur_p->sorted) {
+        // small candidates -> sort inplace
+        llama_token_data_array_sort(cur_p, k);
+        cur_p->sorted = true;
+    }
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = cur_p->size;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        cum_sum += cur_p->data[i].p;
+        cum_sum += pdata[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
         if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
             last_idx = i + 1;
             break;
         }
+
+        // we exceeded the current top-k heuristic -> increase k and continue
+        if (!cur_p->sorted && i == k - 1) {
+            k = cur_p->size;
+            llama_token_data_array_sort(cur_p, k, buf_sort);
+            pdata = buf_sort.data();
+        }
     }
 
     // Resize the output vector to keep only the top-p tokens
+    if (!cur_p->sorted) {
+        std::memcpy(cur_p->data, buf_sort.data(), last_idx*sizeof(llama_token_data));
+        cur_p->sorted = true;
+    }
+
     cur_p->size = last_idx;
 }
 
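The reworked top-p path avoids the unconditional full sort it used to inherit from llama_sampler_softmax_impl: softmax now runs without sorting, and when the candidates are unsorted and number more than 1024, only the top 256 are bucket-sorted into the reusable buf_sort scratch buffer. If the cumulative probability has not reached p by the end of that prefix, k is escalated to the full size and the sort is redone; the kept prefix is finally copied back over cur_p->data. A generic sketch of the same escalation idea over plain probabilities (the function name and thresholds are illustrative, not taken from the patch):

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <vector>

    // returns how many of the highest probabilities are needed to reach p
    static size_t top_p_cutoff(const std::vector<float> & probs, float p) {
        size_t k = probs.size() > 1024 ? 256 : probs.size(); // start with a small sorted prefix
        std::vector<float> sorted(probs);
        std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<float>());

        float cum_sum = 0.0f;
        for (size_t i = 0; i < sorted.size(); ++i) {
            cum_sum += sorted[i];
            if (cum_sum >= p) {
                return i + 1;
            }
            if (i == k - 1 && k < sorted.size()) {
                // the prefix was not enough -> fall back to sorting everything
                k = sorted.size();
                std::sort(sorted.begin(), sorted.end(), std::greater<float>());
            }
        }
        return sorted.size();
    }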
@@ -757,6 +780,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
+            /* .buf_sort = */ {},
         }
     );
 }
@@ -809,9 +833,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
     if (!min_p_applied) {
         // Sort the logits in descending order
         if (!cur_p->sorted) {
-            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
+            llama_token_data_array_sort(cur_p, cur_p->size);
             cur_p->sorted = true;
         }
 