Skip to content

Commit 3bc0a89

Browse files
committed
Evaluate tokens in batches after swapping context
1 parent 489537e commit 3bc0a89

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

examples/main/main.cpp

+11 -5
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
253253
// infinite text generation via context swapping
254254
// if we run out of context:
255255
// - take the n_keep first tokens from the original prompt (via n_past)
256-
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
256+
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
257257
if (n_past + (int) embd.size() > n_ctx) {
258258
const int n_left = n_past - params.n_keep;
259259

@@ -271,13 +271,19 @@ int main(int argc, char ** argv) {
271271
//printf("\n---\n");
272272
}
273273

274-
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
275-
fprintf(stderr, "%s : failed to eval\n", __func__);
276-
return 1;
274+
for (int i = 0; i < (int)embd.size(); i += params.n_batch) {
275+
int n_eval = (int)embd.size()-i;
276+
if (n_eval > params.n_batch) {
277+
n_eval = params.n_batch;
278+
}
279+
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
280+
fprintf(stderr, "%s : failed to eval\n", __func__);
281+
return 1;
282+
}
283+
n_past += n_eval;
277284
}
278285
}
279286

280-
n_past += embd.size();
281287
embd.clear();
282288

283289
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

0 commit comments

Comments (0)