
Commit 0b57c88

sampling : optimize sorting using bucket sort in more places
ggml-ci
1 parent 009b709 commit 0b57c88

File tree

3 files changed, +120 -103 lines changed


include/llama.h

Lines changed: 0 additions & 5 deletions
@@ -1148,11 +1148,6 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
     LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);

-    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
-        "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
-
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// Setting k <= 0 makes this a noop
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);

src/llama-sampling.cpp

Lines changed: 119 additions & 97 deletions
@@ -128,6 +128,77 @@ struct ring_buffer {
     std::vector<T> data;
 };

+static void llama_token_data_array_sort(const llama_token_data_array * cur_p, int k, std::vector<llama_token_data> & res) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    constexpr int   nbuckets     = 128;
+    constexpr float bucket_low   = -10.0f;
+    constexpr float bucket_high  =  10.0f;
+    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+    constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+    std::vector<int> bucket_idx(cur_p->size);
+    std::vector<int> histo(nbuckets, 0);
+
+    for (int i = 0; i < (int)cur_p->size; ++i) {
+        const float val = cur_p->data[i].logit;
+        int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+        ib = std::max(0, std::min(nbuckets - 1, ib));
+        bucket_idx[i] = ib;
+        ++histo[ib];
+    }
+    int nhave = 0;
+    int ib = nbuckets - 1;
+    for ( ; ib >= 0; --ib) {
+        nhave += histo[ib];
+        if (nhave >= k) {
+            break;
+        }
+    }
+    res.resize(nhave);
+    auto * ptr = res.data();
+    std::vector<llama_token_data*> bucket_ptrs;
+    bucket_ptrs.reserve(nbuckets - ib);
+    for (int j = nbuckets - 1; j >= ib; --j) {
+        bucket_ptrs.push_back(ptr);
+        ptr += histo[j];
+    }
+    for (int i = 0; i < (int)cur_p->size; ++i) {
+        int j = bucket_idx[i];
+        if (j >= ib) {
+            *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
+        }
+    }
+
+    ptr = res.data();
+    int ndone = 0;
+    for (int j = nbuckets - 1; j > ib; --j) {
+        std::sort(ptr, ptr + histo[j], comp);
+        ptr += histo[j];
+        ndone += histo[j];
+    }
+    std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+}
+
+static void llama_token_data_array_sort(llama_token_data_array * cur_p, int k) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    if (k <= 128) {
+        std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
+        return;
+    }
+
+    std::vector<llama_token_data> tmp_tokens;
+
+    llama_token_data_array_sort(cur_p, k, tmp_tokens);
+
+    std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+}
+
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
 #ifdef __GNUC__
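
For intuition, the bucket index computed above is a linear map of each logit into one of 128 buckets covering the range [-10, 10], clamped at the edges. A minimal standalone sketch of just that mapping follows; the sample logit values are made up for illustration.

// Sketch of the bucket-index mapping used by llama_token_data_array_sort above.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int   nbuckets     = 128;
    constexpr float bucket_low   = -10.0f;
    constexpr float bucket_high  =  10.0f;
    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); // 6.4 buckets per unit of logit
    constexpr float bucket_inter = -bucket_low * bucket_scale;          // shifts bucket_low to index 0

    const float logits[] = { -12.5f, -3.2f, 0.0f, 4.7f, 11.0f }; // illustrative values

    for (float val : logits) {
        int ib = int(bucket_scale * val + bucket_inter);
        ib = std::max(0, std::min(nbuckets - 1, ib)); // out-of-range logits land in the edge buckets
        std::printf("logit %6.2f -> bucket %3d\n", val, ib);
    }
    return 0;
}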
@@ -200,18 +271,22 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp)
     }
 }

-static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort = true) {
     GGML_ASSERT(cur_p->size > 0);

-    // Sort the logits in descending order
-    if (!cur_p->sorted) {
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
+    // Sort the logits in descending order if requested
+    if (do_sort && !cur_p->sorted) {
+        llama_token_data_array_sort(cur_p, cur_p->size);
         cur_p->sorted = true;
     }

     float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
     float cum_sum = 0.0f;

     for (size_t i = 0; i < cur_p->size; ++i) {
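
When called with do_sort = false, the implementation can no longer assume the maximum logit sits at index 0, hence the extra scan for max_l above. A rough, self-contained sketch of a max-shifted softmax over unsorted logits, with illustrative values only (not the exact library code):

// Numerically stable softmax over unsorted logits: subtract the maximum before exponentiating.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = { 1.5f, -0.3f, 4.2f, 0.0f }; // unsorted, made-up values

    float max_l = logits[0];
    for (size_t i = 1; i < logits.size(); ++i) {
        max_l = std::max(max_l, logits[i]);
    }

    float cum_sum = 0.0f;
    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l); // shifted exponent avoids overflow
        cum_sum += probs[i];
    }
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] /= cum_sum;
        std::printf("logit %5.2f -> p = %.4f\n", logits[i], probs[i]);
    }
    return 0;
}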
@@ -226,7 +301,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
 }

 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
@@ -239,63 +313,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)

     // Sort scores in descending order
     if (!cur_p->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucket_inter = -bucket_low * bucket_scale;
-
-            std::vector<int> bucket_idx(cur_p->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                const float val = cur_p->data[i].logit;
-                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets - 1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) {
-                    break;
-                }
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto * ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets - 1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-        }
+        llama_token_data_array_sort(cur_p, k);
         cur_p->sorted = true;
     }

@@ -576,7 +594,8 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;

-    llama_sampler_softmax_impl(cur_p);
+    // sorting is not necessary here, but for now we are doing it
+    llama_sampler_softmax_impl(cur_p, true);

     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
@@ -626,32 +645,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     );
 }

-// softmax
-
-static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
-    return "softmax";
-}
-
-static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
-    llama_sampler_softmax_impl(cur_p);
-}
-
-static struct llama_sampler_i llama_sampler_softmax_i = {
-    /* .name   = */ llama_sampler_softmax_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_softmax_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
-};
-
-struct llama_sampler * llama_sampler_init_softmax() {
-    return llama_sampler_init(
-        /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr
-    );
-}
-
 // top-k

 struct llama_sampler_top_k {
@@ -699,37 +692,67 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
 struct llama_sampler_top_p {
     const float p;
     const size_t min_keep;
+
+    std::vector<llama_token_data> buf_sort;
 };

 static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
     return "top-p";
 }

 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+    auto * ctx = (llama_sampler_top_p *) smpl->ctx;

     if (ctx->p >= 1.0f) {
         return;
     }

-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, false);
+
+    size_t k = cur_p->size;
+    auto * pdata = cur_p->data;
+
+    auto & buf_sort = ctx->buf_sort;
+
+    // if not sorted, try adaptive top-k sorting
+    if (!cur_p->sorted && cur_p->size > 1024) {
+        k = std::min<size_t>(256, cur_p->size);
+        llama_token_data_array_sort(cur_p, k, buf_sort);
+        pdata = buf_sort.data();
+    } else if (!cur_p->sorted) {
+        // small candidates -> sort inplace
+        llama_token_data_array_sort(cur_p, k);
+        cur_p->sorted = true;
+    }

     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = cur_p->size;

     for (size_t i = 0; i < cur_p->size; ++i) {
-        cum_sum += cur_p->data[i].p;
+        cum_sum += pdata[i].p;

         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
         if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
             last_idx = i + 1;
             break;
         }
+
+        // we exceeded the current top-k heuristic -> increase k and continue
+        if (!cur_p->sorted && i == k - 1) {
+            k = cur_p->size;
+            llama_token_data_array_sort(cur_p, k, buf_sort);
+            pdata = buf_sort.data();
+        }
     }

     // Resize the output vector to keep only the top-p tokens
+    if (!cur_p->sorted) {
+        std::memcpy(cur_p->data, buf_sort.data(), last_idx*sizeof(llama_token_data));
+        cur_p->sorted = true;
+    }
+
     cur_p->size = last_idx;
 }
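
The adaptive path above sorts only a small top-k prefix (up to 256 tokens) of a large candidate set and falls back to a full sort only if the cumulative probability never reaches p within that prefix. Below is a reduced sketch of the same idea on plain floats; all sizes and values are made up for illustration, and the real code operates on llama_token_data and reuses buf_sort instead of copying.

// Adaptive "sort a prefix first, fall back to a full sort" strategy, reduced to plain probabilities.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    // hypothetical, already-normalized probabilities in arbitrary order
    std::vector<float> p = { 0.01f, 0.40f, 0.02f, 0.30f, 0.05f, 0.10f, 0.07f, 0.05f };
    const float top_p = 0.95f;

    // first pass: partially sort only a small prefix (k = 4 here, 256 in the commit)
    size_t k = 4;
    std::vector<float> sorted = p;
    std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<float>());

    float  cum_sum  = 0.0f;
    size_t last_idx = sorted.size();
    for (size_t i = 0; i < sorted.size(); ++i) {
        cum_sum += sorted[i];
        if (cum_sum >= top_p) {
            last_idx = i + 1;
            break;
        }
        // prefix exhausted without reaching top_p -> sort everything and keep going;
        // the first k entries are unchanged, so cum_sum stays valid
        if (i == k - 1) {
            k = sorted.size();
            std::sort(sorted.begin(), sorted.end(), std::greater<float>());
        }
    }

    std::printf("kept %zu of %zu candidates (cum_sum = %.2f)\n", last_idx, sorted.size(), cum_sum);
    return 0;
}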

@@ -757,6 +780,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
+            /* .buf_sort = */ {},
         }
     );
 }
@@ -809,9 +833,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     if (!min_p_applied) {
         // Sort the logits in descending order
         if (!cur_p->sorted) {
-            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
+            llama_token_data_array_sort(cur_p, cur_p->size);
             cur_p->sorted = true;
         }

tests/test-sampling.cpp

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ int main(void) {
     test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);

     test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
-    test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
+    test_sampler_queue(10000, "p", 10000, 0.0003f, 1.0f);
     test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
     test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
     test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
