@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
 
     LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
     LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(ctx));
-        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
     // Tokenize negative prompt
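The `.c_str()` additions here (and in the hunks that follow) fix a printf-style format bug: LOG forwards its arguments to a printf-like sink, and `%s` requires a `const char *`, while LOG_TOKENS_TOSTR_PRETTY evidently returns a `std::string`. Passing the string object itself through the variadic call is undefined behavior. A minimal standalone illustration of the bug class, independent of the LOG macro:

    #include <cstdio>
    #include <string>

    int main() {
        std::string tokens = "[ 1, 15043 ]";
        // printf("tokens: %s\n", tokens);      // undefined behavior: %s wants const char *
        printf("tokens: %s\n", tokens.c_str()); // OK: NUL-terminated C string
        return 0;
    }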
@@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
         guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
-        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
-        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
         params.n_keep = (int)embd_inp.size();
     }
 
-    LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
-    LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+    LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+    LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
 
     // enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
             grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
-    // TODO: replace with ring-buffer
-    std::vector<llama_token> last_tokens(n_ctx);
-    std::fill(last_tokens.begin(), last_tokens.end(), 0);
     LOG_TEE("\n##### Infill mode #####\n\n");
     if (params.infill) {
         printf("\n************\n");
@@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    const int n_vocab = llama_n_vocab(model);
-
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
 
     while (n_remain != 0 || params.interactive) {
         // predict
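This hunk, together with the removal of `last_tokens` above, collapses the hand-rolled sampling state (a candidate buffer sized to the vocabulary, plus a separately threaded grammar) into one heap-allocated `llama_sampling_context` that owns all of it. A rough sketch of the resulting lifecycle, using only the three calls visible in this diff; the `llama_sampling_free` cleanup call is an assumption, named here as the natural counterpart of `llama_sampling_init`:

    // sketch: generation loop built on the refactored sampling API
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    while (n_remain != 0 || params.interactive) {
        // pick the next token; ctx_guidance (may be null) supplies CFG
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

        // record it so repetition penalties and ctx_sampling->prev stay current
        llama_sampling_accept(ctx_sampling, ctx, id);

        // ... decode id and emit the corresponding text piece ...
    }

    llama_sampling_free(ctx_sampling); // assumed counterpart of llama_sampling_init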
@@ -484,7 +477,7 @@ int main(int argc, char ** argv) {
 
             LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
         }
 
@@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
                 input_buf  = embd_guidance.data();
                 input_size = embd_guidance.size();
 
-                LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+                LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
             } else {
                 input_buf  = embd.data();
                 input_size = embd.size();
@@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
                 n_eval = params.n_batch;
             }
 
-            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
             if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
                 LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            last_tokens.erase(last_tokens.begin());
-            last_tokens.push_back(id);
+            llama_sampling_accept(ctx_sampling, ctx, id);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
             embd.push_back(id);
 
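`llama_sampling_accept` now maintains the token history internally, which is why the two manual `last_tokens` updates disappear here: the sampled token lands in `ctx_sampling->prev`, the sliding window the removed `// TODO: replace with ring-buffer` comment was asking for. The next hunk still updates `ctx_sampling->prev` by hand while force-feeding prompt tokens, since those are consumed rather than sampled, and the later eot/eos checks simply read the last entry of that history, along these lines:

    // sketch: inspect the most recently recorded token
    // (the empty() guard is added here for safety; the diff itself omits it)
    if (!ctx_sampling->prev.empty() && ctx_sampling->prev.back() == llama_token_eos(ctx)) {
        // the last token was end-of-sequence
    }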
@@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                last_tokens.erase(last_tokens.begin());
-                last_tokens.push_back(embd_inp[n_consumed]);
+                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed) {
 
             // deal with eot token in infill mode
-            if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
                 if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (last_tokens.back() == llama_token_eos(ctx)) {
+            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
                 const size_t original_size = embd_inp.size();
 
                 const auto line_inp = ::llama_tokenize(ctx, buffer, false);
-                LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+                LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
 
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
 