@@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_chunks = std::stoi(argv[i]);
+        } else if (arg == "-np" || arg == "--parallel") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_parallel = std::stoi(argv[i]);
+        } else if (arg == "-ns" || arg == "--sequences") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_sequences = std::stoi(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
             params.simple_io = true;
+        } else if (arg == "-cb" || arg == "--cont-batching") {
+            params.cont_batching = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "--mlock") {
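Note: the new "-np"/"--parallel", "-ns"/"--sequences" and "-cb"/"--cont-batching" switches above only set plain fields on gpt_params. A minimal sketch of the assumed common.h additions (the field names come from this diff; the default values are illustrative and not confirmed by these hunks):

    // assumed gpt_params additions (defaults illustrative)
    int32_t n_parallel    = 1;     // -np N, --parallel N:  number of parallel sequences to decode
    int32_t n_sequences   = 1;     // -ns N, --sequences N: number of sequences to decode
    bool    cont_batching = false; // -cb,   --cont-batching: continuous (a.k.a. dynamic) batching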
@@ -436,8 +450,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--numa") {
             params.numa = true;
-        } else if (arg == "--export") {
-            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -456,8 +468,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
-        } else if (arg == "--perplexity") {
-            params.perplexity = true;
+        } else if (arg == "--perplexity" || arg == "--all-logits") {
+            params.logits_all = true;
         } else if (arg == "--ppl-stride") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -655,12 +667,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    printf("  --perplexity          compute perplexity over each ctx window of the prompt\n");
+    printf("  --logits-all          return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     printf("  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     printf("  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf("  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+    printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
+    printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
+    printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     if (llama_mlock_supported()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
@@ -685,7 +700,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
-    printf("  --export              export the computation graph to 'llama.ggml'\n");
     printf("  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -738,7 +752,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.f16_kv          = params.memory_f16;
     lparams.use_mmap        = params.use_mmap;
     lparams.use_mlock       = params.use_mlock;
-    lparams.logits_all      = params.perplexity;
+    lparams.logits_all      = params.logits_all;
     lparams.embedding       = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
@@ -782,8 +796,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");

-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
+        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+        llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }

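Note: the warmup above is the first caller of the new decoding entry point. A hedged sketch of the same pattern applied to a longer, already-tokenized prompt follows; only llama_decode, llama_batch_get_one and their argument order are taken from this diff, while the helper itself and the 0-return-means-success convention are assumptions (it relies on the <vector>/<algorithm> includes already present in common.cpp):

    // Hypothetical helper, not part of this diff: decode `tokens` in chunks of n_batch,
    // all assigned to sequence 0, with positions starting at 0.
    static bool decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens,
                              int n_batch, int n_threads) {
        for (size_t i = 0; i < tokens.size(); i += n_batch) {
            const int n_eval = (int) std::min((size_t) n_batch, tokens.size() - i);
            // llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)
            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i, n_eval, (int) i, 0), n_threads) != 0) {
                return false; // assumed: a non-zero return signals a failed decode
            }
        }
        return true;
    }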
@@ -890,7 +905,7 @@ llama_token llama_sample_token(

     llama_token id = 0;

-    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+    float * logits = llama_get_logits_ith(ctx, idx);

     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
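Note: llama_get_logits_ith(ctx, idx) returns the row of logits belonging to the idx-th token of the last decoded batch, replacing the manual llama_get_logits(ctx) + idx * n_vocab arithmetic. A minimal usage sketch (the target token and bias value are illustrative, and it assumes logits were requested for that position):

    float * logits = llama_get_logits_ith(ctx, idx); // one float per vocabulary entry
    logits[llama_token_eos(ctx)] += -5.0f;           // e.g. what a logit_bias entry for EOS would do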
@@ -941,19 +956,19 @@ llama_token llama_sample_token(
         if (mirostat == 1) {
             static float mirostat_mu = 2.0f * mirostat_tau;
             const int mirostat_m = 100;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
         } else if (mirostat == 2) {
             static float mirostat_mu = 2.0f * mirostat_tau;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
         } else {
             // Temperature sampling
             llama_sample_top_k    (ctx, &cur_p, top_k, 1);
             llama_sample_tail_free(ctx, &cur_p, tfs_z, 1);
             llama_sample_typical  (ctx, &cur_p, typical_p, 1);
             llama_sample_top_p    (ctx, &cur_p, top_p, 1);
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);

             {
                 const int n_top = 10;
@@ -1182,7 +1197,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0\n", params.frequency_penalty);
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
@@ -1256,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
     fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", params.temp);

     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);