@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
264
264
// infinite text generation via context swapping
265
265
// if we run out of context:
266
266
// - take the n_keep first tokens from the original prompt (via n_past)
267
- // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
267
+ // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
268
268
if (n_past + (int ) embd.size () > n_ctx) {
269
269
const int n_left = n_past - params.n_keep ;
270
270
@@ -282,13 +282,21 @@ int main(int argc, char ** argv) {
282
282
// printf("\n---\n");
283
283
}
284
284
285
- if (llama_eval (ctx, embd.data (), embd.size (), n_past, params.n_threads )) {
286
- fprintf (stderr, " %s : failed to eval\n " , __func__);
287
- return 1 ;
285
+ // evaluate tokens in batches
286
+ // embd is typically prepared beforehand to fit within a batch, but not always
287
+ for (int i = 0 ; i < (int )embd.size (); i += params.n_batch ) {
288
+ int n_eval = (int )embd.size ()-i;
289
+ if (n_eval > params.n_batch ) {
290
+ n_eval = params.n_batch ;
291
+ }
292
+ if (llama_eval (ctx, &embd[i], n_eval, n_past, params.n_threads )) {
293
+ fprintf (stderr, " %s : failed to eval\n " , __func__);
294
+ return 1 ;
295
+ }
296
+ n_past += n_eval;
288
297
}
289
298
}
290
299
291
- n_past += embd.size ();
292
300
embd.clear ();
293
301
294
302
if ((int ) embd_inp.size () <= n_consumed && !is_interacting) {
0 commit comments