@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
+    llama_context * ctx_guidance = NULL;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (params.cfg_scale > 1.f) {
+        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
+        ctx_guidance = llama_new_context_with_model(model, lparams);
+    }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
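
The guidance context is the core of the change: classifier-free guidance needs a second stream of logits for the same continuation under a negative prompt, so that stream gets its own KV cache while the model weights stay shared. Below is a minimal sketch of the gate, assuming the llama.cpp API used in this diff; the NULL check is an addition of the sketch, not part of the patch.

#include <cstdio>
#include "llama.h"

// Create the optional guidance context. A cfg_scale of 1 reduces the guided
// distribution to the ordinary one, so the extra context (and its extra KV
// cache) is only worth paying for when cfg_scale > 1.f.
static llama_context * maybe_create_guidance_ctx(llama_model * model,
                                                 llama_context_params lparams,
                                                 float cfg_scale) {
    if (cfg_scale <= 1.f) {
        return NULL; // CFG disabled: sample from the main context only
    }
    llama_context * ctx_guidance = llama_new_context_with_model(model, lparams);
    if (ctx_guidance == NULL) { // not checked in the patch; added here for illustration
        fprintf(stderr, "failed to create guidance context\n");
    }
    return ctx_guidance;
}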
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
+    // Add a space in front of the first character to match OG llama tokenizer behavior
+    params.prompt.insert(0, 1, ' ');
 
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, true);
     } else {
         embd_inp = session_tokens;
     }
 
+    // Tokenize negative prompt
+    std::vector<llama_token> guidance_inp;
+    int guidance_offset = 0;
+    int original_prompt_len = 0;
+    if (ctx_guidance) {
+        params.cfg_negative_prompt.insert(0, 1, ' ');
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        original_prompt_len = original_inp.size();
+        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
+    }
+
     const int n_ctx = llama_n_ctx(ctx);
 
     if ((int) embd_inp.size() > n_ctx - 4) {
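
guidance_offset records how many more tokens the negative prompt occupies than the positive one; the guidance context will always sit that many positions ahead of the main context. A made-up worked example (the counts are illustrative, not taken from the patch):

// Hypothetical counts: the negative prompt tokenizes to 12 tokens and the
// positive prompt to 8, so the guidance stream runs 4 positions ahead. That
// difference is what the overflow check further down compensates for.
int example_guidance_offset() {
    const int n_negative_prompt_tokens = 12; // would be guidance_inp.size()
    const int n_positive_prompt_tokens = 8;  // would be original_prompt_len
    return n_negative_prompt_tokens - n_positive_prompt_tokens; // = 4
}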
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < (int) embd_inp.size(); i++) {
             fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
         }
+
+        if (ctx_guidance) {
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+            fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
+            for (int i = 0; i < (int) guidance_inp.size(); i++) {
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+            }
+        }
+
         if (params.n_keep > 0) {
             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
     int n_remain = params.n_predict;
     int n_consumed = 0;
     int n_session_consumed = 0;
+    int n_past_guidance = 0;
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 
     std::vector<llama_token> embd;
+    std::vector<llama_token> embd_guidance;
 
     // do one empty run to warm up the model
     {
@@ -367,11 +398,12 @@ int main(int argc, char ** argv) {
             // if we run out of context:
            // - take the n_keep first tokens from the original prompt (via n_past)
            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() > n_ctx) {
+            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
                 const int n_left = n_past - params.n_keep;
 
                 // always keep the first token - BOS
                 n_past = std::max(1, params.n_keep);
+                n_past_guidance = std::max(1, params.n_keep + guidance_offset);
 
                 // insert n_left/2 tokens at the start of embd from last_n_tokens
                 embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
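
The overflow test gains std::max<int>(0, guidance_offset) because the guidance stream is guidance_offset tokens longer than the main one whenever the negative prompt is longer, so the context swap has to trigger based on the longer of the two streams; a shorter negative prompt clamps to 0 and changes nothing. An illustrative helper (names mirror the diff; the function itself is not part of the patch):

#include <algorithm>

// Returns true when either stream would exceed the context window.
static bool context_would_overflow(int n_past, int n_batch_tokens,
                                   int guidance_offset, int n_ctx) {
    // max(0, guidance_offset): if the negative prompt is shorter, the main
    // stream is the longer one and the original check already covers it.
    return n_past + n_batch_tokens + std::max(0, guidance_offset) > n_ctx;
}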
@@ -412,6 +444,48 @@ int main(int argc, char ** argv) {
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
+
+            if (ctx_guidance) {
+                int input_size = 0;
+                llama_token* input_buf = NULL;
+
+                if (n_past_guidance < (int) guidance_inp.size()) {
+                    // Guidance context should have the same data with these modifications:
+                    //
+                    // * Replace the initial prompt
+                    // * Shift everything by guidance_offset
+                    embd_guidance = guidance_inp;
+                    if (embd.begin() + original_prompt_len < embd.end()) {
+                        embd_guidance.insert(
+                            embd_guidance.end(),
+                            embd.begin() + original_prompt_len,
+                            embd.end()
+                        );
+                    }
+
+                    input_buf = embd_guidance.data();
+                    input_size = embd_guidance.size();
+                    // fprintf(stderr, "\n---------------------\n");
+                    // for (int i = 0; i < (int) embd_guidance.size(); i++) {
+                        // fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
+                    // }
+                    // fprintf(stderr, "\n---------------------\n");
+                } else {
+                    input_buf = embd.data();
+                    input_size = embd.size();
+                }
+
+                for (int i = 0; i < input_size; i += params.n_batch) {
+                    int n_eval = std::min(input_size - i, params.n_batch);
+                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                        fprintf(stderr, "%s : failed to eval\n", __func__);
+                        return 1;
+                    }
+
+                    n_past_guidance += n_eval;
+                }
+            }
+
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
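
While the guidance context is still ingesting its own prompt, the block above rebuilds its batch as the negative-prompt tokens followed by whatever the main context has generated beyond the positive prompt, so both streams see identical continuations from then on. A standalone sketch of that composition rule (llama_token is assumed to be a plain int, as in llama.cpp of this era; the helper is illustrative, not part of the patch):

#include <vector>

using llama_token = int; // assumption: llama.cpp's token type is a plain int here

// guidance_inp        : tokens of the negative prompt
// embd                : the batch about to be evaluated in the main context
// original_prompt_len : token length of the positive prompt
static std::vector<llama_token> build_guidance_batch(
        const std::vector<llama_token> & guidance_inp,
        const std::vector<llama_token> & embd,
        int original_prompt_len) {
    std::vector<llama_token> out = guidance_inp;
    if ((int) embd.size() > original_prompt_len) {
        // append only the part of the main batch beyond the positive prompt,
        // i.e. the tokens both contexts must share
        out.insert(out.end(), embd.begin() + original_prompt_len, embd.end());
    }
    return out;
}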
@@ -431,6 +505,7 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
+        embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
@@ -473,6 +548,10 @@ int main(int argc, char ** argv) {
 
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+                if (ctx_guidance) {
+                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+                }
+
                 // Apply penalties
                 float nl_logit = logits[llama_token_nl()];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
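
llama_sample_classifier_free_guidance is where the two logit streams meet. The sketch below is not the library routine and it ignores params.cfg_smooth_factor; under those assumptions it shows the usual CFG formula: normalize both logit vectors to log-probabilities, then push the positive-prompt distribution away from the negative-prompt distribution by cfg_scale (a scale of 1 leaves sampling unchanged).

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Normalize a logit vector to log-probabilities in place.
static void log_softmax(std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float l : logits) sum += std::exp(l - max_logit);
    const float log_sum = std::log(sum) + max_logit;
    for (float & l : logits) l -= log_sum;
}

// logits     : per-token logits from the main context (positive prompt)
// logits_neg : per-token logits from the guidance context (negative prompt)
static std::vector<float> apply_cfg_sketch(std::vector<float> logits,
                                           std::vector<float> logits_neg,
                                           float cfg_scale) {
    log_softmax(logits);
    log_softmax(logits_neg);
    for (size_t i = 0; i < logits.size(); ++i) {
        // guided = negative + cfg_scale * (positive - negative)
        logits[i] = logits_neg[i] + cfg_scale * (logits[i] - logits_neg[i]);
    }
    return logits;
}

With cfg_scale > 1, tokens that are relatively more likely under the negative prompt get pushed down, which is also why the patch only creates ctx_guidance when cfg_scale > 1.f.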
@@ -668,6 +747,7 @@ int main(int argc, char ** argv) {
     }
 
     llama_print_timings(ctx);
+    if (ctx_guidance) { llama_free(ctx_guidance); }
     llama_free(ctx);
     llama_free_model(model);
 