Skip to content

Commit d1f0210

Browse files
committed
examples : evaluate tokens in batches after swapping context
1 parent e9298af commit d1f0210

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

examples/main/main.cpp

+13 -5
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
264264
// infinite text generation via context swapping
265265
// if we run out of context:
266266
// - take the n_keep first tokens from the original prompt (via n_past)
267-
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
267+
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
268268
if (n_past + (int) embd.size() > n_ctx) {
269269
const int n_left = n_past - params.n_keep;
270270

@@ -282,13 +282,21 @@ int main(int argc, char ** argv) {
282282
//printf("\n---\n");
283283
}
284284

285-
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
286-
fprintf(stderr, "%s : failed to eval\n", __func__);
287-
return 1;
285+
// evaluate tokens in batches
286+
// embd is typically prepared beforehand to fit within a batch, but not always
287+
for (int i = 0; i < (int)embd.size(); i += params.n_batch) {
288+
int n_eval = (int)embd.size()-i;
289+
if (n_eval > params.n_batch) {
290+
n_eval = params.n_batch;
291+
}
292+
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
293+
fprintf(stderr, "%s : failed to eval\n", __func__);
294+
return 1;
295+
}
296+
n_past += n_eval;
288297
}
289298
}
290299

291-
n_past += embd.size();
292300
embd.clear();
293301

294302
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

0 commit comments

Comments
 (0)