@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
253
253
// infinite text generation via context swapping
254
254
// if we run out of context:
255
255
// - take the n_keep first tokens from the original prompt (via n_past)
256
- // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
256
+ // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
257
257
if (n_past + (int ) embd.size () > n_ctx) {
258
258
const int n_left = n_past - params.n_keep ;
259
259
@@ -271,13 +271,19 @@ int main(int argc, char ** argv) {
271
271
// printf("\n---\n");
272
272
}
273
273
274
- if (llama_eval (ctx, embd.data (), embd.size (), n_past, params.n_threads )) {
275
- fprintf (stderr, " %s : failed to eval\n " , __func__);
276
- return 1 ;
274
+ for (int i = 0 ; i < (int )embd.size (); i += params.n_batch ) {
275
+ int n_eval = (int )embd.size ()-i;
276
+ if (n_eval > params.n_batch ) {
277
+ n_eval = params.n_batch ;
278
+ }
279
+ if (llama_eval (ctx, &embd[i], n_eval, n_past, params.n_threads )) {
280
+ fprintf (stderr, " %s : failed to eval\n " , __func__);
281
+ return 1 ;
282
+ }
283
+ n_past += n_eval;
277
284
}
278
285
}
279
286
280
- n_past += embd.size ();
281
287
embd.clear ();
282
288
283
289
if ((int ) embd_inp.size () <= n_consumed && !is_interacting) {
0 commit comments