
Commit bcb3ffe

ggerganov and slaren authored and committed
llama : custom attention mask + parallel decoding + no context swaps (ggml-org#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (ggml-org#3254)
  * better solution for p0 computation
  * fix rope
  * simpler rope implementation
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable parameters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (ggml-org#3272)
  * offload KQ_mask with all models
  * fix rope shift
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now (ggml-ci)
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (ggml-org#3275) (ggml-ci)
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close ggml-org#3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache

Co-authored-by: slaren <slarengh@gmail.com>
1 parent a53b8a8 commit bcb3ffe

35 files changed: +2687 −660 lines
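To make the API shift concrete before the diffs, here is a minimal sketch (not code from the commit) of prompt evaluation followed by greedy continuation with the new `llama_decode()`/`llama_batch_get_one()` calls that appear below. Positions and the sequence id are now passed explicitly instead of the old `llama_eval(..., n_past, ...)` argument. The helper name `greedy_continue` and the use of `llama_get_logits()`/`llama_n_vocab()` for the argmax step are assumptions for illustration only.

```cpp
#include "llama.h"
#include <vector>

// Sketch: decode a prompt once, then continue it greedily, one token per step.
static void greedy_continue(llama_context * ctx, std::vector<llama_token> tokens,
                            int n_predict, int n_threads) {
    // evaluate the whole prompt as one batch: positions 0..n-1, sequence id 0
    llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), 0, 0), n_threads);

    llama_pos n_cur = (llama_pos) tokens.size();

    for (int i = 0; i < n_predict; ++i) {
        // logits of the last evaluated token (only those are computed by default)
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(ctx); // context-based accessor at this revision

        llama_token best = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[best]) {
                best = t;
            }
        }
        // (collect `best` into the output here)

        // feed the sampled token back with an explicit position; no context swap is needed,
        // since the KV cache is now addressed by (position, sequence id)
        llama_decode(ctx, llama_batch_get_one(&best, 1, n_cur, 0), n_threads);
        n_cur += 1;
    }
}
```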

.gitignore (+2)

```diff
@@ -51,7 +51,9 @@ models-mnt
 /save-load-state
 /server
 /simple
+/batched
 /speculative
+/parallel
 /train-text-from-scratch
 /vdot
 build-info.h
```

Makefile (+7 −1)

```diff
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative parallel tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama
@@ -520,6 +520,9 @@ main: examples/main/main.cpp build-info.h ggml.
 simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
@@ -566,6 +569,9 @@ beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o co
 speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifdef LLAMA_METAL
 metal: examples/metal/metal.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
```

common/common.cpp (+29 −14)

```diff
@@ -317,6 +317,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_chunks = std::stoi(argv[i]);
+        } else if (arg == "-np" || arg == "--parallel") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_parallel = std::stoi(argv[i]);
+        } else if (arg == "-ns" || arg == "--sequences") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_sequences = std::stoi(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -360,6 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
            params.simple_io = true;
+        } else if (arg == "-cb" || arg == "--cont-batching") {
+            params.cont_batching = true;
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "--mlock") {
@@ -436,8 +450,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--numa") {
             params.numa = true;
-        } else if (arg == "--export") {
-            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -456,8 +468,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
-        } else if (arg == "--perplexity") {
-            params.perplexity = true;
+        } else if (arg == "--perplexity" || arg == "--all-logits") {
+            params.logits_all = true;
         } else if (arg == "--ppl-stride") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -655,12 +667,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     printf("  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    printf("  --perplexity          compute perplexity over each ctx window of the prompt\n");
+    printf("  --logits-all          return logits for all tokens in the batch (default: disabled)\n");
     printf("  --hellaswag           compute HellaSwag score over random tasks from datafile supplied with -f\n");
     printf("  --hellaswag-tasks N   number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
     printf("  --keep N              number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     printf("  --draft N             number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+    printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
+    printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
+    printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     if (llama_mlock_supported()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
@@ -685,7 +700,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
-    printf("  --export              export the computation graph to 'llama.ggml'\n");
     printf("  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -738,7 +752,7 @@ struct llama_context_params llama_context_params_from_gpt_param
     lparams.f16_kv     = params.memory_f16;
     lparams.use_mmap   = params.use_mmap;
     lparams.use_mlock  = params.use_mlock;
-    lparams.logits_all = params.perplexity;
+    lparams.logits_all = params.logits_all;
     lparams.embedding  = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
@@ -782,8 +796,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");
 
-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
+        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+        llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }
 
@@ -890,7 +905,7 @@ llama_token llama_sample_token(
 
     llama_token id = 0;
 
-    float * logits = llama_get_logits(ctx) + idx * n_vocab;
+    float * logits = llama_get_logits_ith(ctx, idx);
 
     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -941,19 +956,19 @@ llama_token llama_sample_token(
         if (mirostat == 1) {
             static float mirostat_mu = 2.0f * mirostat_tau;
             const int mirostat_m = 100;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
         } else if (mirostat == 2) {
             static float mirostat_mu = 2.0f * mirostat_tau;
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp(ctx, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
         } else {
             // Temperature sampling
             llama_sample_top_k      (ctx, &cur_p, top_k, 1);
             llama_sample_tail_free  (ctx, &cur_p, tfs_z, 1);
             llama_sample_typical    (ctx, &cur_p, typical_p, 1);
             llama_sample_top_p      (ctx, &cur_p, top_p, 1);
-            llama_sample_temperature(ctx, &cur_p, temp);
+            llama_sample_temp       (ctx, &cur_p, temp);
 
             {
                 const int n_top = 10;
@@ -1182,7 +1197,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
@@ -1256,6 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
     fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+    fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", params.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
```
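The llama_sample_token hunk above drops the manual pointer arithmetic on `llama_get_logits()` in favor of the new `llama_get_logits_ith()` helper. A tiny sketch of the equivalence, assuming the context-based `llama_n_vocab(ctx)` accessor of this revision; the helper name is illustrative:

```cpp
#include "llama.h"

// Returns the logits row for batch position `idx`; both expressions address the same
// data when logits were computed for that position.
static const float * logits_for_idx(llama_context * ctx, int idx) {
    // old pattern (before this commit):
    //   return llama_get_logits(ctx) + idx * llama_n_vocab(ctx);
    // new helper introduced alongside the batch API:
    return llama_get_logits_ith(ctx, idx);
}
```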

common/common.h (+5 −3)

```diff
@@ -42,6 +42,8 @@ struct gpt_params {
     int32_t n_keep             = 0;  // number of tokens to keep from initial prompt
     int32_t n_draft            = 16; // number of tokens to draft during speculative decoding
     int32_t n_chunks           = -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel         = 1;  // number of parallel sequences to decode
+    int32_t n_sequences        = 1;  // number of sequences to decode
     int32_t n_gpu_layers       = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu           = 0;  // the GPU that is used for scratch and small tensors
@@ -107,16 +109,16 @@ struct gpt_params {
     bool interactive_first = false; // wait for user input immediately
     bool multiline_input   = false; // reverse the usage of `\`
     bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
+    bool cont_batching     = false; // insert new sequences for decoding on-the-fly
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos        = false; // ignore generated EOS tokens
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool penalize_nl       = true;  // consider newlines as a repeatable token
-    bool perplexity        = false; // compute perplexity over the prompt
+    bool logits_all        = false; // return logits for all tokens in the batch
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool numa              = false; // attempt optimizations that help on some NUMA systems
-    bool export_cgraph     = false; // export the computation graph
     bool verbose_prompt    = false; // print prompt tokens before generation
 };
 
@@ -181,7 +183,7 @@ std::string llama_detokenize_bpe(
 //   - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
 //   - grammar: grammar to use for sampling, ignore if NULL
 //   - last_tokens: needed for repetition penalty, ignore if empty
-//   - idx: sample from llama_get_logits(ctx) + idx * n_vocab
+//   - idx: sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
 //   - token: sampled token
```

examples/CMakeLists.txt (+2)

```diff
@@ -23,7 +23,9 @@ else()
     add_subdirectory(train-text-from-scratch)
     add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
+    add_subdirectory(batched)
     add_subdirectory(speculative)
+    add_subdirectory(parallel)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
```

examples/baby-llama/baby-llama.cpp (+31 −6)

```diff
@@ -554,6 +554,14 @@ static struct ggml_tensor * forward(
     struct ggml_tensor * kc = kv_self.k;
     struct ggml_tensor * vc = kv_self.v;
 
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    {
+        int * data = (int *) KQ_pos->data;
+        for (int i = 0; i < N; ++i) {
+            data[i] = n_past + i;
+        }
+    }
+
     // inpL shape [n_embd,N,1,1]
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
     for (int il = 0; il < n_layer; ++il) {
@@ -581,8 +589,8 @@ static struct ggml_tensor * forward(
         // wk   shape [n_embd, n_embd, 1, 1]
         // Qcur shape [n_embd/n_head, n_head, N, 1]
         // Kcur shape [n_embd/n_head, n_head, N, 1]
-        struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
-        struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+        struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
+        struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
 
         // store key and value to memory
         {
@@ -808,9 +816,18 @@ static struct ggml_tensor * forward_batch(
     struct ggml_tensor * kc = kv_self.k;
     struct ggml_tensor * vc = kv_self.v;
 
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    {
+        int * data = (int *) KQ_pos->data;
+        for (int i = 0; i < N; ++i) {
+            data[i] = n_past + i;
+        }
+    }
+
     // inpL shape [n_embd,N*n_batch,1]
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
     assert_shape_2d(inpL, n_embd, N*n_batch);
+
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;
 
@@ -838,8 +855,8 @@ static struct ggml_tensor * forward_batch(
         // wk   shape [n_embd, n_embd, 1, 1]
         // Qcur shape [n_embd/n_head, n_head, N, n_batch]
         // Kcur shape [n_embd/n_head, n_head, N, n_batch]
-        struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
-        struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
+        struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
+        struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
         assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
         assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
 
@@ -1097,6 +1114,14 @@ static struct ggml_tensor * forward_lora(
     struct ggml_tensor * kc = kv_self.k;
     struct ggml_tensor * vc = kv_self.v;
 
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    {
+        int * data = (int *) KQ_pos->data;
+        for (int i = 0; i < N; ++i) {
+            data[i] = n_past + i;
+        }
+    }
+
     // inpL shape [n_embd,N,1,1]
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
     for (int il = 0; il < n_layer; ++il) {
@@ -1130,7 +1155,7 @@ static struct ggml_tensor * forward_lora(
                                         model->layers[il].wqb,
                                         cur)),
                                 n_embd/n_head, n_head, N),
-                        n_past, n_rot, 0, 0);
+                        KQ_pos, n_rot, 0, 0);
         struct ggml_tensor * Kcur = ggml_rope(ctx0,
                         ggml_reshape_3d(ctx0,
                             ggml_mul_mat(ctx0,
@@ -1139,7 +1164,7 @@ static struct ggml_tensor * forward_lora(
                                         model->layers[il].wkb,
                                         cur)),
                                 n_embd/n_head, n_head, N),
-                        n_past, n_rot, 0, 0);
+                        KQ_pos, n_rot, 0, 0);
 
         // store key and value to memory
         {
```
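The hunks above all apply the same pattern: `ggml_rope` now takes an I32 tensor with one position per token instead of a scalar `n_past`, which is what makes per-sequence offsets and KV cache shifts possible. A small sketch of that pattern in isolation; the helper name `rope_with_positions` is illustrative, and `qk` stands for the already-reshaped Q or K tensor from the caller:

```cpp
#include "ggml.h"

// Build an I32 positions tensor [n_past, n_past+1, ..., n_past+N-1] and apply RoPE with it.
static struct ggml_tensor * rope_with_positions(
        struct ggml_context * ctx0, struct ggml_tensor * qk, int n_past, int N, int n_rot) {
    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    {
        int * data = (int *) KQ_pos->data;
        for (int i = 0; i < N; ++i) {
            data[i] = n_past + i; // one absolute position per token
        }
    }
    // mode = 0, n_ctx = 0, matching the baby-llama calls above
    return ggml_rope(ctx0, qk, KQ_pos, n_rot, 0, 0);
}
```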

examples/batched/CMakeLists.txt (new file, +5)

```cmake
set(TARGET batched)
add_executable(${TARGET} batched.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
```

examples/batched/README.md (new file, +44)

# llama.cpp/example/batched

The example demonstrates batched generation from a given prompt

```bash
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4

...

main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113

Hello my name is

main: generating 4 sequences ...

main: stream 0 finished
main: stream 1 finished
main: stream 2 finished
main: stream 3 finished

sequence 0:

Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b

sequence 1:

Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between

sequence 2:

Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am

sequence 3:

Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and

main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s

llama_print_timings: load time   =  587.00 ms
llama_print_timings: sample time =    2.56 ms /   112 runs   (    0.02 ms per token, 43664.72 tokens per second)
llama_print_timings: prompt eval time = 4089.11 ms /   118 tokens (   34.65 ms per token,    28.86 tokens per second)
llama_print_timings: eval time   =    0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: total time  = 4156.04 ms
```
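For reference, a rough sketch of the flow such an example implies: evaluate the shared prompt once as sequence 0, share its KV cells with the other sequences, then submit one token per stream in a single `llama_batch` each step. This is a simplified illustration, not the actual batched.cpp; it assumes `llama_kv_cache_seq_cp` and the flat per-token `llama_batch` fields (token/pos/seq_id/logits) of this revision, and uses greedy argmax, whereas the real example samples stochastically (which is why its streams diverge).

```cpp
#include "llama.h"
#include <vector>

// Simplified per-step batched decoding over n_parallel streams that share one prompt.
static void decode_parallel(llama_context * ctx, std::vector<llama_token> prompt,
                            int n_parallel, int n_len, int n_threads) {
    // 1) evaluate the prompt once, as sequence 0 (logits of its last token are computed)
    llama_decode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0), n_threads);

    // 2) share the prompt's KV cells with the other sequences instead of re-evaluating it
    for (int32_t s = 1; s < n_parallel; ++s) {
        llama_kv_cache_seq_cp(ctx, 0, s, 0, (llama_pos) prompt.size());
    }

    const int n_vocab = llama_n_vocab(ctx);
    auto argmax = [n_vocab](const float * logits) {
        llama_token best = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[best]) best = t;
        }
        return best;
    };

    // first continuation token, shared by all streams (they all continue the same prompt)
    std::vector<llama_token>  tok(n_parallel, argmax(llama_get_logits(ctx)));
    std::vector<llama_pos>    pos(n_parallel);
    std::vector<llama_seq_id> seq(n_parallel);
    std::vector<int8_t>       out(n_parallel, 1); // request logits for every stream

    for (llama_pos n_cur = (llama_pos) prompt.size(); n_cur < n_len; ++n_cur) {
        for (int32_t s = 0; s < n_parallel; ++s) {
            pos[s] = n_cur; // same position, different sequence id per stream
            seq[s] = s;
        }

        llama_batch batch = {};
        batch.n_tokens = n_parallel;
        batch.token    = tok.data();
        batch.pos      = pos.data();
        batch.seq_id   = seq.data();
        batch.logits   = out.data();

        llama_decode(ctx, batch, n_threads); // one call advances every stream

        for (int32_t s = 0; s < n_parallel; ++s) {
            tok[s] = argmax(llama_get_logits_ith(ctx, s)); // row s = stream s
        }
    }
}
```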
