From 0e6db6fec1fd759b799a511018d6b21664cfb76d Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Sun, 8 Oct 2023 03:52:45 -0600
Subject: [PATCH 1/6] Fix mirostat state when using multiple sequences

---
 common/common.cpp              | 21 +++++++++++++++------
 common/common.h                | 16 ++++++++++++++--
 examples/parallel/parallel.cpp |  4 ++--
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 0f55c33a713a7..752785f3aa856 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -940,10 +940,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-            const struct gpt_params & params,
+                  struct gpt_params & params,
     const std::vector<llama_token> & last_tokens,
      std::vector<llama_token_data> & candidates,
-                               int   idx) {
+                         const int   idx,
+                      llama_seq_id   seq) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -1011,15 +1012,23 @@ llama_token llama_sample_token(
         // Greedy sampling
         id = llama_sample_token_greedy(ctx, &cur_p);
     } else {
+        float * mirostat_mu = NULL;
+        if (mirostat > 0) {
+            seq = std::max(0, seq); // Deal with people passing -1 or something.
+            auto mu_it = params.sampler_state.find(seq);
+            if (mu_it == params.sampler_state.end()) {
+                const llama_sampler_state new_state = { 2.0f * mirostat_tau };
+                mu_it = params.sampler_state.insert({seq, new_state}).first;
+            }
+            mirostat_mu = &mu_it->second.mirostat_mu;
+        }
         if (mirostat == 1) {
-            static float mirostat_mu = 2.0f * mirostat_tau;
             const int mirostat_m = 100;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu);
         } else if (mirostat == 2) {
-            static float mirostat_mu = 2.0f * mirostat_tau;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
diff --git a/common/common.h b/common/common.h
index c802152791797..3fa77e8776e49 100644
--- a/common/common.h
+++ b/common/common.h
@@ -33,6 +33,10 @@
 //
 int32_t get_num_physical_cores();
 
+typedef struct llama_sampler_state {
+    float mirostat_mu; // mirostat sampler state
+} llama_sampler_state;
+
 struct gpt_params {
     uint32_t seed      = -1; // RNG seed
     int32_t  n_threads = get_num_physical_cores();
@@ -54,6 +58,9 @@ struct gpt_params {
     float    rope_freq_base  = 0.0f; // RoPE base frequency
     float    rope_freq_scale = 0.0f; // RoPE frequency scaling factor
 
+    // per sequence sampler state
+    std::unordered_map<llama_seq_id, llama_sampler_state> sampler_state;
+
     // sampling parameters
     int32_t top_k = 40;    // <= 0 to use vocab size
     float   top_p = 0.95f; // 1.0 = disabled
@@ -186,6 +193,9 @@ std::string llama_detokenize_bpe(
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
+// Note: When using multiple sequences, it is the caller's responsibility to delete
+// the item in params.sampler_state when a sequence ends and samplers that rely on
+// state are being used.
 //
 // required:
 //   - ctx:    context to use for sampling
@@ -196,6 +206,7 @@ std::string llama_detokenize_bpe(
 //   - grammar:     grammar to use for sampling, ignore if NULL
 //   - last_tokens: needed for repetition penalty, ignore if empty
 //   - idx:         sample from llama_get_logits_ith(ctx, idx)
+//   - seq:         sequence id to associate sampler state with (currently only used by mirostat)
 //
 // returns:
 //   - token: sampled token
@@ -205,10 +216,11 @@ llama_token llama_sample_token(
                   struct llama_context * ctx,
                   struct llama_context * ctx_guidance,
                   struct llama_grammar * grammar,
-            const struct gpt_params & params,
+                  struct gpt_params & params,
     const std::vector<llama_token> & last_tokens,
      std::vector<llama_token_data> & candidates,
-                               int   idx = 0);
+                         const int   idx = 0,
+                      llama_seq_id   seq = 0);
 //
 // YAML utils
 //
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 721888da7de94..8806cf7243fb1 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +384,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-
+                params.sampler_state.erase(client.seq_id);
                 client.seq_id = -1;
             }

From fad923a82ddba3a0a68ae3e75fd183d5d5119620 Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Sun, 8 Oct 2023 09:16:59 -0600
Subject: [PATCH 2/6] Fix mirostat by completely refactoring sampling!
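[Reviewer note, not part of the original patch] The refactor below moves the sampling
parameters and all mutable sampler state out of gpt_params and into a new
llama_sampling_state defined in common/sampling.{h,cpp}. Judging from the call sites
updated later in this patch (main, infill, server), the single-context flow looks
roughly like the sketch below; the surrounding variables (model, ctx, ctx_guidance,
grammar, last_tokens) are the ones those examples already have and are shown only for
illustration.

    // build the sampling state once from the parsed parameters; grammar may be NULL
    llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar);

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(model));

    // each step: penalties, grammar constraints and mirostat state are handled inside
    const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state,
                                              last_tokens, candidates);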
--- Makefile | 86 ++++--- common/CMakeLists.txt | 2 + common/common.cpp | 237 +++++-------------- common/common.h | 68 +----- common/sampling.cpp | 160 +++++++++++++ common/sampling.h | 89 +++++++ examples/embd-input/embd-input-lib.cpp | 19 +- examples/infill/infill.cpp | 18 +- examples/main/main.cpp | 18 +- examples/parallel/parallel.cpp | 6 +- examples/save-load-state/save-load-state.cpp | 5 +- examples/server/server.cpp | 100 ++++---- examples/speculative/speculative.cpp | 12 +- 13 files changed, 467 insertions(+), 353 deletions(-) create mode 100644 common/sampling.cpp create mode 100644 common/sampling.h diff --git a/Makefile b/Makefile index 40187c4a25e62..b0a4d76d39ec1 100644 --- a/Makefile +++ b/Makefile @@ -172,6 +172,24 @@ else MK_CPPFLAGS += -DNDEBUG endif +ifdef LLAMA_SANITIZE_THREAD + MK_CFLAGS += -fsanitize=thread -g + MK_CXXFLAGS += -fsanitize=thread -g + MK_LDFLAGS += -fsanitize=thread -g +endif + +ifdef LLAMA_SANITIZE_ADDRESS + MK_CFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + MK_CXXFLAGS += -fsanitize=address -fno-omit-frame-pointer -g + MK_LDFLAGS += -fsanitize=address -fno-omit-frame-pointer -g +endif + +ifdef LLAMA_SANITIZE_UNDEFINED + MK_CFLAGS += -fsanitize=undefined -g + MK_CXXFLAGS += -fsanitize=undefined -g + MK_LDFLAGS += -fsanitize=undefined -g +endif + ifdef LLAMA_SERVER_VERBOSE MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) endif @@ -520,7 +538,13 @@ OBJS += ggml-alloc.o ggml-backend.o llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ -common.o: common/common.cpp common/common.h build-info.h common/log.h +COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h +COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o + +common.o: common/common.cpp $(COMMON_H_DEPS) + $(CXX) $(CXXFLAGS) -c $< -o $@ + +sampling.o: common/sampling.cpp $(COMMON_H_DEPS) $(CXX) $(CXXFLAGS) -c $< -o $@ console.o: common/console.cpp common/console.h @@ -542,19 +566,19 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' @echo -infill: examples/infill/infill.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) +infill: examples/infill/infill.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) +simple: examples/simple/simple.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS) +batched: examples/batched/batched.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) @@ -563,53 +587,53 @@ quantize: examples/quantize/quantize.cpp build-info.h ggml. 
quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o $(OBJS) +perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o $(OBJS) +embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS) +save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) +server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) +$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) -embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) +embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. 
-lembdinput gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS) +train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS) +llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS) +baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS) +beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS) +finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o common.o $(OBJS) +export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) +speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS) +parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) ifdef LLAMA_METAL @@ -645,40 +669,40 @@ vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) q8dot: pocs/vdot/q8dot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o common.o grammar-parser.o $(OBJS) +tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) +tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o 
$(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) tests/test-c.o: tests/test-c.c llama.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 951aa8340c7e4..fbb0ff0952ac7 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -5,6 +5,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + sampling.h + sampling.cpp console.h console.cpp grammar-parser.h diff --git a/common/common.cpp b/common/common.cpp index 752785f3aa856..4214e63afd87a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -107,6 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::string arg; gpt_params default_params; const std::string arg_prefix = "--"; + llama_sampling_params & sparams = params.sampling_params; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -184,7 +185,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.top_k = std::stoi(argv[i]); + sparams.top_k = std::stoi(argv[i]); } else if (arg == "-c" || arg == "--ctx-size") { if (++i >= argc) { invalid_param = true; @@ -216,73 +217,73 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.top_p = std::stof(argv[i]); + sparams.top_p = std::stof(argv[i]); } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; break; } - params.temp = 
std::stof(argv[i]); + sparams.temp = std::stof(argv[i]); } else if (arg == "--tfs") { if (++i >= argc) { invalid_param = true; break; } - params.tfs_z = std::stof(argv[i]); + sparams.tfs_z = std::stof(argv[i]); } else if (arg == "--typical") { if (++i >= argc) { invalid_param = true; break; } - params.typical_p = std::stof(argv[i]); + sparams.typical_p = std::stof(argv[i]); } else if (arg == "--repeat-last-n") { if (++i >= argc) { invalid_param = true; break; } - params.repeat_last_n = std::stoi(argv[i]); + sparams.repeat_last_n = std::stoi(argv[i]); } else if (arg == "--repeat-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.repeat_penalty = std::stof(argv[i]); + sparams.repeat_penalty = std::stof(argv[i]); } else if (arg == "--frequency-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.frequency_penalty = std::stof(argv[i]); + sparams.frequency_penalty = std::stof(argv[i]); } else if (arg == "--presence-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.presence_penalty = std::stof(argv[i]); + sparams.presence_penalty = std::stof(argv[i]); } else if (arg == "--mirostat") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat = std::stoi(argv[i]); + sparams.mirostat = std::stoi(argv[i]); } else if (arg == "--mirostat-lr") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat_eta = std::stof(argv[i]); + sparams.mirostat_eta = std::stof(argv[i]); } else if (arg == "--mirostat-ent") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat_tau = std::stof(argv[i]); + sparams.mirostat_tau = std::stof(argv[i]); } else if (arg == "--cfg-negative-prompt") { if (++i >= argc) { invalid_param = true; break; } - params.cfg_negative_prompt = argv[i]; + sparams.cfg_negative_prompt = argv[i]; } else if (arg == "--cfg-negative-prompt-file") { if (++i >= argc) { invalid_param = true; @@ -294,16 +295,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.cfg_negative_prompt)); - if (!params.cfg_negative_prompt.empty() && params.cfg_negative_prompt.back() == '\n') { - params.cfg_negative_prompt.pop_back(); + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); + if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { + sparams.cfg_negative_prompt.pop_back(); } } else if (arg == "--cfg-scale") { if (++i >= argc) { invalid_param = true; break; } - params.cfg_scale = std::stof(argv[i]); + sparams.cfg_scale = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -512,7 +513,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--ignore-eos") { params.ignore_eos = true; } else if (arg == "--no-penalize-nl") { - params.penalize_nl = false; + sparams.penalize_nl = false; } else if (arg == "-l" || arg == "--logit-bias") { if (++i >= argc) { invalid_param = true; @@ -524,7 +525,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::string value_str; try { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? 
-1.0f : 1.0f); } else { throw std::exception(); } @@ -627,6 +628,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + const llama_sampling_params & sparams = params.sampling_params; + printf("usage: %s [options]\n", argv[0]); printf("\n"); printf("options:\n"); @@ -659,19 +662,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); - printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); - printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); - printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); + printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); + printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); + printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); + printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); + printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); + printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); + printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); printf(" --mirostat N use Mirostat sampling.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); - printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); - printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); - printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); + printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); + printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); + printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias 
TOKEN_ID(+/-)BIAS\n"); printf(" modifies the likelihood of token appearing in the completion,\n"); printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); @@ -682,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" negative prompt to use for guidance. (default: empty)\n"); printf(" --cfg-negative-prompt-file FNAME\n"); printf(" negative prompt file to use for guidance. (default: empty)\n"); - printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); + printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); @@ -690,7 +693,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); - printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); + printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); @@ -840,7 +843,7 @@ std::tuple llama_init_from_gpt_par } if (params.ignore_eos) { - params.logit_bias[llama_token_eos(lctx)] = -INFINITY; + params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY; } { @@ -932,136 +935,6 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & last_tokens, - std::vector & candidates, - const int idx, - llama_seq_id seq) { - const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? 
n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; - - llama_token id = 0; - - float * logits = llama_get_logits_ith(ctx, idx); - - // Apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - candidates.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; - - if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); - } - - // apply penalties - if (!last_tokens.empty()) { - const float nl_logit = logits[llama_token_nl(ctx)]; - const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); - - llama_sample_repetition_penalty(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(ctx)) { - cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } - - if (grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, grammar); - } - - if (temp <= 0) { - // Greedy sampling - id = llama_sample_token_greedy(ctx, &cur_p); - } else { - float * mirostat_mu = NULL; - if (mirostat > 0) { - seq = std::max(0, seq); // Deal with people passing -1 or something. 
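// [Reviewer note, not part of the original patch] The block being removed here is the
// per-sequence mu lookup that the previous patch used to replace the old
// `static float mirostat_mu`; it reappears below in common/sampling.cpp as
// sampling_get_sequence_state(). Reduced to its essentials it amounts to the
// following sketch (the map and helper names are made up for illustration):

#include <unordered_map>

#include "llama.h" // for llama_seq_id

static std::unordered_map<llama_seq_id, float> mu_by_seq;

static float & mirostat_mu_for(llama_seq_id seq, float mirostat_tau) {
    auto it = mu_by_seq.find(seq);
    if (it == mu_by_seq.end()) {
        // each sequence starts at mu = 2 * tau instead of sharing one static value
        it = mu_by_seq.insert({seq, 2.0f * mirostat_tau}).first;
    }
    return it->second;
}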
- auto mu_it = params.sampler_state.find(seq); - if (mu_it == params.sampler_state.end()) { - const llama_sampler_state new_state = { 2.0f * mirostat_tau }; - mu_it = params.sampler_state.insert({seq, new_state}).first; - } - mirostat_mu = &mu_it->second.mirostat_mu; - } - if (mirostat == 1) { - const int mirostat_m = 100; - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, mirostat_mu); - } else if (mirostat == 2) { - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_mu); - } else { - // Temperature sampling - size_t min_keep = std::max(1, params.n_probs); - llama_sample_top_k (ctx, &cur_p, top_k, min_keep); - llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); - llama_sample_typical (ctx, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx, &cur_p, top_p, min_keep); - llama_sample_temp(ctx, &cur_p, temp); - - { - const int n_top = 10; - LOG("top %d candidates:\n", n_top); - - for (int i = 0; i < n_top; i++) { - const llama_token id = cur_p.data[i].id; - LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); - } - } - - id = llama_sample_token(ctx, &cur_p); - - LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); - } - } - // printf("`%d`", candidates_p.size); - - if (grammar != NULL) { - llama_grammar_accept_token(ctx, grammar, id); - } - - return id; -} - // // YAML utils // @@ -1213,6 +1086,8 @@ std::string get_sortable_timestamp() { void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + const llama_sampling_params & sparams = params.sampling_params; + fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); fprintf(stream, "build_number: %d\n", BUILD_NUMBER); fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); @@ -1259,21 +1134,21 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - dump_string_yaml_multiline(stream, "cfg_negative_prompt", params.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", params.cfg_scale); + dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? 
"true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx)); - const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY; + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx)); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); @@ -1286,7 +1161,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); - for (std::pair lb : params.logit_bias) { + for (std::pair lb : sparams.logit_bias) { if (ignore_eos && lb.first == logit_bias_eos->first) { continue; } @@ -1310,30 +1185,30 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", params.n_probs); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); - fprintf(stream, "no_penalize_nl: %s # default: false\n", !params.penalize_nl ? "true" : "false"); + fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false"); fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", params.presence_penalty); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? 
"true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", params.repeat_penalty); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); fprintf(stream, "reverse_prompt:\n"); for (std::string ap : params.antiprompt) { @@ -1351,15 +1226,15 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", params.temp); + fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); - fprintf(stream, "tfs: %f # default: 1.0\n", params.tfs_z); + fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", params.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", params.top_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p); + fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); } diff --git a/common/common.h b/common/common.h index 3fa77e8776e49..fa115536b64a0 100644 --- a/common/common.h +++ b/common/common.h @@ -4,6 +4,8 @@ #include "llama.h" +#include "sampling.h" + #define LOG_NO_FILE_LINE_FUNCTION #include "log.h" @@ -33,10 +35,6 @@ // int32_t get_num_physical_cores(); -typedef struct llama_sampler_state { - float mirostat_mu; // mirostat sampler state -} llama_sampler_state; - struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); @@ -53,34 +51,12 @@ struct gpt_params { int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_beams = 0; // if non-zero then use beam search of given width. 
float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor - // per sequence sampler state - std::unordered_map sampler_state; - - // sampling parameters - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - - std::unordered_map logit_bias; // logit bias for specific tokens - - // Classifier-Free Guidance - // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // How strong is guidance + // // sampling parameters + struct llama_sampling_params sampling_params; std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model_draft = ""; // draft model for speculative decoding @@ -122,7 +98,6 @@ struct gpt_params { bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens bool instruct = false; // instruction mode (used for Alpaca models) - bool penalize_nl = true; // consider newlines as a repeatable token bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory @@ -187,41 +162,6 @@ std::string llama_detokenize_bpe( llama_context * ctx, const std::vector & tokens); -// -// Sampling utils -// - -// this is a common sampling function used across the examples for convenience -// it can serve as a starting point for implementing your own sampling function -// Note: When using multiple sequences, it is the caller's responsibility to delete -// the item in params.sampler_state when a sequence ends and samplers that rely on -// state are being used. 
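// [Reviewer note, not part of the original patch] The caller-managed map that the
// note removed here describes is replaced in this patch by llama_sampling_state:
// per-sequence state (the mirostat mu plus its own llama_grammar copy) is created
// lazily on first use and released with llama_sampling_state_reset(). The
// multi-sequence pattern, paraphrased from the examples/parallel changes further
// down (the client fields belong to that example and are shown only for
// illustration):

llama_sampling_state sampling_state = llama_sampling_state_init(params, NULL);

// sample for one client/sequence; state for client.seq_id is created on demand
const llama_token id = llama_sample_token(ctx, NULL, sampling_state,
                                          client.tokens_prev, candidates,
                                          client.i_batch - i, client.seq_id);

// when the sequence finishes, drop its sampler state (this also frees the
// per-sequence grammar copy)
llama_sampling_state_reset(sampling_state, client.seq_id);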
-// -// required: -// - ctx: context to use for sampling -// - params: sampling parameters -// -// optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - grammar: grammar to use for sampling, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with (currently only used by mirostat) -// -// returns: -// - token: sampled token -// - candidates: vector of candidate tokens -// -llama_token llama_sample_token( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_grammar * grammar, - struct gpt_params & params, - const std::vector & last_tokens, - std::vector & candidates, - const int idx = 0, - llama_seq_id seq = 0); - // // YAML utils // diff --git a/common/sampling.cpp b/common/sampling.cpp new file mode 100644 index 0000000000000..5e8ad1db43582 --- /dev/null +++ b/common/sampling.cpp @@ -0,0 +1,160 @@ +#include "sampling.h" + +llama_sampling_state::~llama_sampling_state() { + for (auto & it : sequence_states) { + if (it.second.grammar != NULL) { + llama_grammar_free(it.second.grammar); + it.second.grammar = NULL; + } + } +} + +llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar) { + llama_sampling_state result; + + result.params = params.sampling_params; + result.grammar = grammar; + return result; +} + +// Creates the state if it doesn't exist, so this always return something. +static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling_state & state, const llama_seq_id seq) { + const auto it = state.sequence_states.find(seq); + if (it != state.sequence_states.end()) { + return it->second; + } + llama_sampler_sequence_state new_state = { + 2.0f * state.params.mirostat_tau, + state.grammar != NULL ? llama_grammar_copy(state.grammar) : NULL, + }; + return state.sequence_states.insert({seq, new_state}).first->second; +} + +bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq) { + const auto it = state.sequence_states.find(seq); + if (it == state.sequence_states.end()) return false; + if (it->second.grammar != NULL) { + llama_grammar_free(it->second.grammar); + it->second.grammar = NULL; + } + state.sequence_states.erase(it); + return true; +} + +llama_token llama_sample_token( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_sampling_state & state, + const std::vector & last_tokens, + std::vector & candidates, + const int idx, + llama_seq_id seq) { + const int n_ctx = llama_n_ctx(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + const llama_sampling_params & params = state.params; + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? 
n_ctx : params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.presence_penalty; + const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + llama_token id = 0; + + float * logits = llama_get_logits_ith(ctx, idx); + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + candidates.clear(); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + + if (ctx_guidance) { + llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); + } + + // apply penalties + if (!last_tokens.empty()) { + const float nl_logit = logits[llama_token_nl(ctx)]; + const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); + + llama_sample_repetition_penalty(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(ctx)) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + llama_sampler_sequence_state & seq_state = sampling_get_sequence_state(state, seq); + + if (seq_state.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, seq_state.grammar); + } + + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &cur_p); + } else { + if (mirostat == 1) { + const int mirostat_m = 100; + llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_state.mirostat_mu); + } else if (mirostat == 2) { + llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_state.mirostat_mu); + } else { + // Temperature sampling + size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx, &cur_p, top_k, min_keep); + llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); + llama_sample_typical (ctx, &cur_p, typical_p, min_keep); + llama_sample_top_p (ctx, &cur_p, top_p, min_keep); + llama_sample_temp(ctx, &cur_p, temp); + + { + const int n_top = 10; + LOG("top %d candidates:\n", n_top); + + for (int i = 0; i < n_top; i++) { + const llama_token id = cur_p.data[i].id; + LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); + } + } + + id = llama_sample_token(ctx, &cur_p); + + LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); + } + } + // printf("`%d`", candidates_p.size); + + if (seq_state.grammar != NULL) { + llama_grammar_accept_token(ctx, seq_state.grammar, id); + } + + return id; +} diff --git a/common/sampling.h b/common/sampling.h new file mode 100644 index 0000000000000..6c5f8749e27e0 --- /dev/null +++ b/common/sampling.h @@ -0,0 +1,89 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include + +// sampling parameters +typedef struct llama_sampling_params { + int32_t 
top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled + int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.00f; // 0.0 = disabled + float presence_penalty = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + + bool penalize_nl = true; // consider newlines as a repeatable token + + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + + // Classifier-Free Guidance + // https://arxiv.org/abs/2306.17806 + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // How strong is guidance + + std::unordered_map logit_bias; // logit bias for specific tokens + +} llama_sampling_params; + +// per-sequence sampler state +typedef struct llama_sampler_sequence_state { + float mirostat_mu; // mirostat sampler state + llama_grammar * grammar; +} llama_sampler_sequence_state; + +// general sampler state +typedef struct llama_sampling_state { + ~llama_sampling_state(); + + llama_sampling_params params; + + std::unordered_map sequence_states; + + llama_grammar * grammar; +} llama_sampling_state; + +#include "common.h" + +// Create a new sampling state instance. +llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar); + +bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq = 0); + +// this is a common sampling function used across the examples for convenience +// it can serve as a starting point for implementing your own sampling function +// Note: When using multiple sequences, it is the caller's responsibility to delete +// the item in params.sampler_state when a sequence ends and samplers that rely on +// state are being used. +// +// required: +// - ctx: context to use for sampling +// - params: sampling parameters +// +// optional: +// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL +// - grammar: grammar to use for sampling, ignore if NULL +// - last_tokens: needed for repetition penalty, ignore if empty +// - idx: sample from llama_get_logits_ith(ctx, idx) +// - seq: sequence id to associate sampler state with (currently only used by mirostat) +// +// returns: +// - token: sampled token +// - candidates: vector of candidate tokens +// +llama_token llama_sample_token( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_sampling_state & state, + const std::vector & last_tokens, + std::vector & candidates, + const int idx = 0, + llama_seq_id seq = 0); diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 99e6bdad5ac45..87a5a1c26f88b 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){ llama_token sampling_id(struct MyModel* mymodel) { llama_context* ctx = mymodel->ctx; gpt_params params = mymodel->params; + llama_sampling_params & sparams = params.sampling_params; // int n_ctx = llama_n_ctx(ctx); // out of user input, sample next token - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? 
llama_n_vocab(llama_get_model(ctx)) : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; + const float temp = sparams.temp; + const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k; + const float top_p = sparams.top_p; + const float tfs_z = sparams.tfs_z; + const float typical_p = sparams.typical_p; // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; // const float repeat_penalty = params.repeat_penalty; // const float alpha_presence = params.presence_penalty; // const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; + const int mirostat = sparams.mirostat; + const float mirostat_tau = sparams.mirostat_tau; + const float mirostat_eta = sparams.mirostat_eta; // const bool penalize_nl = params.penalize_nl; llama_token id = 0; @@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) { auto n_vocab = llama_n_vocab(llama_get_model(ctx)); // Apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) { logits[it->first] += it->second; } diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 9ec75ce425b2a..cde82267874c2 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -104,6 +104,7 @@ static void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; + llama_sampling_params & sparams = params.sampling_params; g_params = ¶ms; if (!gpt_params_parse(argc, argv, params)) { @@ -206,7 +207,7 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (params.cfg_scale > 1.f) { + if (sparams.cfg_scale > 1.f) { struct llama_context_params lparams = llama_context_params_from_gpt_params(params); ctx_guidance = llama_new_context_with_model(model, lparams); } @@ -257,9 +258,9 @@ int main(int argc, char ** argv) { int guidance_offset = 0; int original_prompt_len = 0; if (ctx_guidance) { - LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt)); + LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos); + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); @@ -300,7 +301,7 @@ int main(int argc, char ** argv) { if (ctx_guidance) { LOG_TEE("\n"); - LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); + LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); for (int i = 0; i < (int) guidance_inp.size(); i++) { LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); @@ -346,7 +347,7 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = 
%d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", - params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); + sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); @@ -364,8 +365,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); { - auto it = params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) { + auto it = sparams.logit_bias.find(llama_token_eos(ctx)); + if (it != sparams.logit_bias.end() && it->second == -INFINITY) { LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); } } @@ -422,6 +423,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); + llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -540,7 +542,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); + const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 775a5a201e5b8..89abfd55b0264 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -109,6 +109,7 @@ int main(int argc, char ** argv) { if (!gpt_params_parse(argc, argv, params)) { return 1; } + llama_sampling_params & sparams = params.sampling_params; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("main", "log")); @@ -179,7 +180,7 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); std::tie(model, ctx) = llama_init_from_gpt_params(params); - if (params.cfg_scale > 1.f) { + if (sparams.cfg_scale > 1.f) { struct llama_context_params lparams = llama_context_params_from_gpt_params(params); ctx_guidance = llama_new_context_with_model(model, lparams); } @@ -257,9 +258,9 @@ int main(int argc, char ** argv) { int guidance_offset = 0; int original_prompt_len = 0; if (ctx_guidance) { - LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt)); + LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos); + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); @@ -343,7 +344,7 @@ int main(int argc, char ** argv) { if (ctx_guidance) { LOG_TEE("\n"); - LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); + LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); LOG_TEE("%s: 
number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); for (int i = 0; i < (int) guidance_inp.size(); i++) { LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); @@ -395,7 +396,7 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", - params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); + sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); @@ -413,8 +414,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); { - auto it = params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) { + auto it = sparams.logit_bias.find(llama_token_eos(ctx)); + if (it != sparams.logit_bias.end() && it->second == -INFINITY) { LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); } } @@ -469,6 +470,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); + llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -625,7 +627,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates); + const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 8806cf7243fb1..cb4c4ba96964e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -125,6 +125,8 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); + llama_sampling_state sampling_state = llama_sampling_state_init(params, NULL); + // load the prompts from an external file if there are any if (params.prompt.empty()) { printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); @@ -339,7 +341,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); + const llama_token id = llama_sample_token(ctx, NULL, sampling_state, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -384,7 +386,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - params.sampler_state.erase(client.seq_id); + 
llama_sampling_state_reset(sampling_state, client.seq_id); client.seq_id = -1; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index acc6dbdfd07d0..f9e3c98a38a40 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -8,9 +8,10 @@ int main(int argc, char ** argv) { gpt_params params; + llama_sampling_params & sparams = params.sampling_params; params.seed = 42; params.n_threads = 4; - params.repeat_last_n = 64; + sparams.repeat_last_n = 64; params.prompt = "The quick brown fox"; if (!gpt_params_parse(argc, argv, params)) { @@ -24,7 +25,7 @@ int main(int argc, char ** argv) { } auto n_past = 0; - auto last_n_tokens_data = std::vector(params.repeat_last_n, 0); + auto last_n_tokens_data = std::vector(sparams.repeat_last_n, 0); // init llama_model * model; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c53a64867336f..08bdaa1899d43 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -200,6 +200,7 @@ struct llama_server_context llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; + llama_sampling_state sampling_state; int n_ctx; grammar_parser::parse_state parsed_grammar; @@ -254,6 +255,7 @@ struct llama_server_context if (grammar != nullptr) { llama_grammar_free(grammar); grammar = nullptr; + sampling_state = llama_sampling_state_init(params, NULL); } } @@ -329,8 +331,8 @@ struct llama_server_context grammar_parser::print_grammar(stderr, parsed_grammar); { - auto it = params.logit_bias.find(llama_token_eos(ctx)); - if (it != params.logit_bias.end() && it->second == -INFINITY) { + auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx)); + if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) { LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); } } @@ -339,6 +341,7 @@ struct llama_server_context grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + sampling_state = llama_sampling_state_init(params, grammar); return true; } @@ -539,12 +542,12 @@ struct llama_server_context std::vector candidates; candidates.reserve(llama_n_vocab(model)); - result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates); + result.tok = llama_sample_token(ctx, NULL, sampling_state, last_n_tokens, candidates); llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - const int32_t n_probs = params.n_probs; - if (params.temp <= 0 && n_probs > 0) + const int32_t n_probs = params.sampling_params.n_probs; + if (params.sampling_params.temp <= 0 && n_probs > 0) { // For llama_sample_token_greedy we need to sort candidates llama_sample_softmax(ctx, &candidates_p); @@ -619,7 +622,7 @@ struct llama_server_context const std::string token_text = token_with_probs.tok == -1 ? 
"" : llama_token_to_piece(ctx, token_with_probs.tok); generated_text += token_text; - if (params.n_probs > 0) + if (params.sampling_params.n_probs > 0) { generated_token_probs.push_back(token_with_probs); } @@ -1007,34 +1010,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, static json format_generation_settings(llama_server_context &llama) { - const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx)); - const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && + const auto & sparams = llama.params.sampling_params; + const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx)); + const bool ignore_eos = eos_bias != sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); return json{ {"n_ctx", llama.n_ctx}, {"model", llama.params.model_alias}, {"seed", llama.params.seed}, - {"temp", llama.params.temp}, - {"top_k", llama.params.top_k}, - {"top_p", llama.params.top_p}, - {"tfs_z", llama.params.tfs_z}, - {"typical_p", llama.params.typical_p}, - {"repeat_last_n", llama.params.repeat_last_n}, - {"repeat_penalty", llama.params.repeat_penalty}, - {"presence_penalty", llama.params.presence_penalty}, - {"frequency_penalty", llama.params.frequency_penalty}, - {"mirostat", llama.params.mirostat}, - {"mirostat_tau", llama.params.mirostat_tau}, - {"mirostat_eta", llama.params.mirostat_eta}, - {"penalize_nl", llama.params.penalize_nl}, + {"temp", sparams.temp}, + {"top_k", sparams.top_k}, + {"top_p", sparams.top_p}, + {"tfs_z", sparams.tfs_z}, + {"typical_p", sparams.typical_p}, + {"repeat_last_n", sparams.repeat_last_n}, + {"repeat_penalty", sparams.repeat_penalty}, + {"presence_penalty", sparams.presence_penalty}, + {"frequency_penalty", sparams.frequency_penalty}, + {"mirostat", sparams.mirostat}, + {"mirostat_tau", sparams.mirostat_tau}, + {"mirostat_eta", sparams.mirostat_eta}, + {"penalize_nl", sparams.penalize_nl}, {"stop", llama.params.antiprompt}, {"n_predict", llama.params.n_predict}, {"n_keep", llama.params.n_keep}, {"ignore_eos", ignore_eos}, {"stream", llama.stream}, - {"logit_bias", llama.params.logit_bias}, - {"n_probs", llama.params.n_probs}, + {"logit_bias", sparams.logit_bias}, + {"n_probs", sparams.n_probs}, {"grammar", llama.params.grammar}, }; } @@ -1083,7 +1087,7 @@ static json format_final_response(llama_server_context &llama, const std::string {"timings", format_timings(llama)}, }; - if (llama.params.n_probs > 0) + if (llama.params.sampling_params.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1099,7 +1103,7 @@ static json format_partial_response( {"stop", false}, }; - if (llama.params.n_probs > 0) + if (llama.params.sampling_params.n_probs > 0) { res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); } @@ -1131,26 +1135,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v static void parse_options_completion(const json &body, llama_server_context &llama) { gpt_params default_params; + const auto & default_sparams = default_params.sampling_params; + auto & sparams = llama.params.sampling_params; llama.stream = json_value(body, "stream", false); llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict); - llama.params.top_k = json_value(body, "top_k", default_params.top_k); - llama.params.top_p = json_value(body, "top_p", default_params.top_p); - llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z); - llama.params.typical_p = 
json_value(body, "typical_p", default_params.typical_p); - llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n); - llama.params.temp = json_value(body, "temperature", default_params.temp); - llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty); - llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty); - llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty); - llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat); - llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau); - llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta); - llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl); + sparams.top_k = json_value(body, "top_k", default_sparams.top_k); + sparams.top_p = json_value(body, "top_p", default_sparams.top_p); + sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z); + sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p); + sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n); + sparams.temp = json_value(body, "temperature", default_sparams.temp); + sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty); + sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty); + sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty); + sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat); + sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); + sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); + sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl); llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); llama.params.seed = json_value(body, "seed", default_params.seed); llama.params.grammar = json_value(body, "grammar", default_params.grammar); - llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs); + sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs); if (body.count("prompt") != 0) { @@ -1161,10 +1167,10 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.prompt = ""; } - llama.params.logit_bias.clear(); + sparams.logit_bias.clear(); if (json_value(body, "ignore_eos", false)) { - llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; + sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; } const auto &logit_bias = body.find("logit_bias"); @@ -1180,11 +1186,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla { if (el[1].is_number()) { - llama.params.logit_bias[tok] = el[1].get(); + sparams.logit_bias[tok] = el[1].get(); } else if (el[1].is_boolean() && !el[1].get()) { - llama.params.logit_bias[tok] = -INFINITY; + sparams.logit_bias[tok] = -INFINITY; } } } @@ -1204,6 +1210,8 @@ static void parse_options_completion(const json &body, llama_server_context &lla } } + llama.sampling_state = llama_sampling_state_init(llama.params, llama.grammar); + LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } @@ -1412,7 +1420,7 @@ int main(int argc, char **argv) } auto probs = llama.generated_token_probs; - if 
(llama.params.n_probs > 0 && llama.stopped_word) { + if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) { const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); } @@ -1464,7 +1472,7 @@ int main(int argc, char **argv) std::vector probs_output = {}; - if (llama.params.n_probs > 0) { + if (llama.params.sampling_params.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); @@ -1585,7 +1593,7 @@ int main(int argc, char **argv) std::vector probs_output = {}; - if (llama.params.n_probs > 0) { + if (llama.params.sampling_params.n_probs > 0) { const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 75a2e5e22d046..6a07c1c176f46 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -125,6 +125,8 @@ int main(int argc, char ** argv) { grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar_tgt); + const auto t_dec_start = ggml_time_us(); while (true) { @@ -134,7 +136,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + llama_token id = llama_sample_token(ctx_tgt, NULL, sampling_state, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -211,7 +213,13 @@ int main(int argc, char ** argv) { if (grammar_dft) { llama_grammar_free(grammar_dft); } - grammar_dft = llama_grammar_copy(grammar_tgt); + // Note: Hardcoded to sequence id 0, if this ever supports parallel generation + // that will need to change. + auto it = sampling_state.sequence_states.find(0); + GGML_ASSERT(it != sampling_state.sequence_states.end()); + // This is necessary because each sequence id in sequence_states + // uses a copy of the original grammar. + grammar_dft = llama_grammar_copy(it->second.grammar); LOG("copied target grammar to draft grammar\n"); } From 52def09a3167cbb2cf762100f60e95c797732e6f Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sun, 8 Oct 2023 09:39:27 -0600 Subject: [PATCH 3/6] Try to fix zig build. 
--- build.zig | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/build.zig b/build.zig index fdc5bc084eb46..0b74cee485320 100644 --- a/build.zig +++ b/build.zig @@ -128,17 +128,18 @@ pub fn build(b: *std.build.Builder) !void { const llama = make.obj("llama", "llama.cpp"); const common = make.obj("common", "common/common.cpp"); const console = make.obj("console", "common/console.cpp"); + const sampling = make.obj("sampling", "common/sampling.cpp"); const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp"); const train = make.obj("train", "common/train.cpp"); - _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, console, grammar_parser }); + _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser }); _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common }); _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train }); _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train }); - const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, grammar_parser }); + const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser }); if (server.target.isWindows()) { server.linkSystemLibrary("ws2_32"); } From 01bef0290020936baf90e9aa2c65c831894a77ea Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sun, 8 Oct 2023 11:59:07 -0600 Subject: [PATCH 4/6] Export function to fetch/create default sampler states Code formatting cleanups and add some comments Silence a warning about id not being used when logging is disabled --- common/sampling.cpp | 18 ++++++++++++------ common/sampling.h | 29 +++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 5e8ad1db43582..05751d91bdc2c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -9,7 +9,9 @@ llama_sampling_state::~llama_sampling_state() { } } -llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar) { +llama_sampling_state llama_sampling_state_init( + const struct gpt_params & params, + llama_grammar * grammar) { llama_sampling_state result; result.params = params.sampling_params; @@ -17,8 +19,10 @@ llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, return result; } -// Creates the state if it doesn't exist, so this always return something. -static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling_state & state, const llama_seq_id seq) { +// Note: Creates the state if it doesn't exist, so this always return something. 
+llama_sampler_sequence_state & llama_sampling_get_sequence_state( + llama_sampling_state & state, + const llama_seq_id seq) { const auto it = state.sequence_states.find(seq); if (it != state.sequence_states.end()) { return it->second; @@ -30,7 +34,9 @@ static llama_sampler_sequence_state & sampling_get_sequence_state(llama_sampling return state.sequence_states.insert({seq, new_state}).first->second; } -bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq) { +bool llama_sampling_state_reset( + llama_sampling_state & state, + const llama_seq_id seq) { const auto it = state.sequence_states.find(seq); if (it == state.sequence_states.end()) return false; if (it->second.grammar != NULL) { @@ -109,7 +115,7 @@ llama_token llama_sample_token( } } - llama_sampler_sequence_state & seq_state = sampling_get_sequence_state(state, seq); + llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq); if (seq_state.grammar != NULL) { llama_sample_grammar(ctx, &cur_p, seq_state.grammar); @@ -141,6 +147,7 @@ llama_token llama_sample_token( for (int i = 0; i < n_top; i++) { const llama_token id = cur_p.data[i].id; + (void)id; // To avoid a warning that id is unused when logging is disabled. LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); } } @@ -150,7 +157,6 @@ llama_token llama_sample_token( LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); } } - // printf("`%d`", candidates_p.size); if (seq_state.grammar != NULL) { llama_grammar_accept_token(ctx, seq_state.grammar, id); diff --git a/common/sampling.h b/common/sampling.h index 6c5f8749e27e0..48702e2dd06de 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -44,19 +44,40 @@ typedef struct llama_sampler_sequence_state { typedef struct llama_sampling_state { ~llama_sampling_state(); + // parameters that will be used for sampling and when creating + // new llama_sampler_sequence_state instances llama_sampling_params params; + // map of sequence ids to sampler states std::unordered_map sequence_states; + // when non-NULL, new instances of llama_sampler_sequence_state + // will get a copy of the grammar here + // note: only the pointer is stored here, it is not a copy of + // the grammar and shouldn't be freed llama_grammar * grammar; } llama_sampling_state; #include "common.h" // Create a new sampling state instance. -llama_sampling_state llama_sampling_state_init(const struct gpt_params & params, llama_grammar * grammar); - -bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id seq = 0); +llama_sampling_state llama_sampling_state_init( + const struct gpt_params & params, + llama_grammar * grammar = NULL); + +// Fetches the sampler state for the specified sequence id (defaults to 0). +// If the state for that sequence id doesn't already exist, it will be created with +// default values based on the parameters in the state argument. +llama_sampler_sequence_state & llama_sampling_get_sequence_state( + llama_sampling_state & state, + const llama_seq_id seq = 0); + +// Reset the sampler states for the supplied sequence id (defaults to 0). +// This is necessary to reuse a sequence id or free memory used by sequences +// that are no longer required. 
+bool llama_sampling_state_reset( + llama_sampling_state & state, + const llama_seq_id seq = 0); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function @@ -73,7 +94,7 @@ bool llama_sampling_state_reset(llama_sampling_state & state, const llama_seq_id // - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty // - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with (currently only used by mirostat) +// - seq: sequence id to associate sampler state with // // returns: // - token: sampled token From 4a34e635007702358a409561c896cced88ce7886 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 11 Oct 2023 04:14:47 -0600 Subject: [PATCH 5/6] Apply some renaming suggestions. Fix comments that were out of sync with the pull. --- common/sampling.cpp | 60 ++++++++++++++-------------- common/sampling.h | 60 ++++++++++++++-------------- examples/infill/infill.cpp | 4 +- examples/main/main.cpp | 4 +- examples/parallel/parallel.cpp | 6 +-- examples/server/server.cpp | 10 ++--- examples/speculative/speculative.cpp | 10 ++--- 7 files changed, 76 insertions(+), 78 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 05751d91bdc2c..78cb038853c96 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,7 +1,7 @@ #include "sampling.h" -llama_sampling_state::~llama_sampling_state() { - for (auto & it : sequence_states) { +llama_sampling_context::~llama_sampling_context() { + for (auto & it : sequence_contexts) { if (it.second.grammar != NULL) { llama_grammar_free(it.second.grammar); it.second.grammar = NULL; @@ -9,48 +9,48 @@ llama_sampling_state::~llama_sampling_state() { } } -llama_sampling_state llama_sampling_state_init( +llama_sampling_context llama_sampling_context_init( const struct gpt_params & params, llama_grammar * grammar) { - llama_sampling_state result; + llama_sampling_context result; result.params = params.sampling_params; result.grammar = grammar; return result; } -// Note: Creates the state if it doesn't exist, so this always return something. -llama_sampler_sequence_state & llama_sampling_get_sequence_state( - llama_sampling_state & state, - const llama_seq_id seq) { - const auto it = state.sequence_states.find(seq); - if (it != state.sequence_states.end()) { +// Note: Creates the context if it doesn't exist, so this always return something. +llama_sampler_sequence_context & llama_sampling_get_sequence_context( + llama_sampling_context & sampling_ctx, + const llama_seq_id seq) { + const auto it = sampling_ctx.sequence_contexts.find(seq); + if (it != sampling_ctx.sequence_contexts.end()) { return it->second; } - llama_sampler_sequence_state new_state = { - 2.0f * state.params.mirostat_tau, - state.grammar != NULL ? llama_grammar_copy(state.grammar) : NULL, + llama_sampler_sequence_context new_ctx = { + 2.0f * sampling_ctx.params.mirostat_tau, + sampling_ctx.grammar != NULL ? 
llama_grammar_copy(sampling_ctx.grammar) : NULL, }; - return state.sequence_states.insert({seq, new_state}).first->second; + return sampling_ctx.sequence_contexts.insert({seq, new_ctx}).first->second; } -bool llama_sampling_state_reset( - llama_sampling_state & state, - const llama_seq_id seq) { - const auto it = state.sequence_states.find(seq); - if (it == state.sequence_states.end()) return false; +bool llama_sampling_context_reset( + llama_sampling_context & sampling_ctx, + const llama_seq_id seq) { + const auto it = sampling_ctx.sequence_contexts.find(seq); + if (it == sampling_ctx.sequence_contexts.end()) return false; if (it->second.grammar != NULL) { llama_grammar_free(it->second.grammar); it->second.grammar = NULL; } - state.sequence_states.erase(it); + sampling_ctx.sequence_contexts.erase(it); return true; } -llama_token llama_sample_token( +llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_state & state, + struct llama_sampling_context & sampling_ctx, const std::vector & last_tokens, std::vector & candidates, const int idx, @@ -58,7 +58,7 @@ llama_token llama_sample_token( const int n_ctx = llama_n_ctx(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const llama_sampling_params & params = state.params; + const llama_sampling_params & params = sampling_ctx.params; const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; @@ -115,10 +115,10 @@ llama_token llama_sample_token( } } - llama_sampler_sequence_state & seq_state = llama_sampling_get_sequence_state(state, seq); + llama_sampler_sequence_context & seq_ctx = llama_sampling_get_sequence_context(sampling_ctx, seq); - if (seq_state.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, seq_state.grammar); + if (seq_ctx.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, seq_ctx.grammar); } if (temp <= 0) { @@ -128,10 +128,10 @@ llama_token llama_sample_token( if (mirostat == 1) { const int mirostat_m = 100; llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_state.mirostat_mu); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_ctx.mirostat_mu); } else if (mirostat == 2) { llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_state.mirostat_mu); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_ctx.mirostat_mu); } else { // Temperature sampling size_t min_keep = std::max(1, params.n_probs); @@ -158,8 +158,8 @@ llama_token llama_sample_token( } } - if (seq_state.grammar != NULL) { - llama_grammar_accept_token(ctx, seq_state.grammar, id); + if (seq_ctx.grammar != NULL) { + llama_grammar_accept_token(ctx, seq_ctx.grammar, id); } return id; diff --git a/common/sampling.h b/common/sampling.h index 48702e2dd06de..96891640f76bc 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -34,64 +34,62 @@ typedef struct llama_sampling_params { } llama_sampling_params; -// per-sequence sampler state -typedef struct llama_sampler_sequence_state { +// per-sequence sampler context +typedef struct llama_sampler_sequence_context { float mirostat_mu; // mirostat sampler state llama_grammar * grammar; -} llama_sampler_sequence_state; +} llama_sampler_sequence_context; -// general sampler state -typedef struct llama_sampling_state { - ~llama_sampling_state(); +// 
general sampler context +typedef struct llama_sampling_context { + ~llama_sampling_context(); // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_state instances + // new llama_sampler_sequence_context instances llama_sampling_params params; - // map of sequence ids to sampler states - std::unordered_map sequence_states; + // map of sequence ids to sampler contexts + std::unordered_map sequence_contexts; - // when non-NULL, new instances of llama_sampler_sequence_state + // when non-NULL, new instances of llama_sampler_sequence_context // will get a copy of the grammar here // note: only the pointer is stored here, it is not a copy of // the grammar and shouldn't be freed llama_grammar * grammar; -} llama_sampling_state; +} llama_sampling_context; #include "common.h" -// Create a new sampling state instance. -llama_sampling_state llama_sampling_state_init( +// Create a new sampling context instance. +llama_sampling_context llama_sampling_context_init( const struct gpt_params & params, llama_grammar * grammar = NULL); -// Fetches the sampler state for the specified sequence id (defaults to 0). -// If the state for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the state argument. -llama_sampler_sequence_state & llama_sampling_get_sequence_state( - llama_sampling_state & state, - const llama_seq_id seq = 0); +// Fetches the sampler context for the specified sequence id (defaults to 0). +// If the context for that sequence id doesn't already exist, it will be created with +// default values based on the parameters in the sampling_ctx argument. +llama_sampler_sequence_context & llama_sampling_get_sequence_context( + llama_sampling_context & sampling_ctx, + const llama_seq_id seq = 0); -// Reset the sampler states for the supplied sequence id (defaults to 0). +// Reset the sampler context for the supplied sequence id (defaults to 0). // This is necessary to reuse a sequence id or free memory used by sequences // that are no longer required. -bool llama_sampling_state_reset( - llama_sampling_state & state, - const llama_seq_id seq = 0); +bool llama_sampling_context_reset( + llama_sampling_context & sampling_ctx, + const llama_seq_id seq = 0); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function -// Note: When using multiple sequences, it is the caller's responsibility to delete -// the item in params.sampler_state when a sequence ends and samplers that rely on -// state are being used. 
+// Note: When using multiple sequences, it is the caller's responsibility to call +// llama_sampling_context_reset when a sequence ends // // required: -// - ctx: context to use for sampling -// - params: sampling parameters +// - ctx: context to use for sampling +// - sampling_ctx: sampling-specific context // // optional: // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty // - idx: sample from llama_get_logits_ith(ctx, idx) // - seq: sequence id to associate sampler state with @@ -100,10 +98,10 @@ bool llama_sampling_state_reset( // - token: sampled token // - candidates: vector of candidate tokens // -llama_token llama_sample_token( +llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_state & state, + struct llama_sampling_context & sampling_ctx, const std::vector & last_tokens, std::vector & candidates, const int idx = 0, diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index cde82267874c2..cd9e6b14202f7 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -423,7 +423,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar); + llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -542,7 +542,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 89abfd55b0264..eaa99459adf4a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -470,7 +470,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar); + llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -627,7 +627,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sample_token(ctx, ctx_guidance, sampling_state, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index cb4c4ba96964e..50025a71c9614 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_sampling_state sampling_state = llama_sampling_state_init(params, NULL); + llama_sampling_context sampling_context = llama_sampling_context_init(params, NULL); // load the prompts from an external file if there are any if (params.prompt.empty()) { @@ -341,7 +341,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // 
client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sample_token(ctx, NULL, sampling_state, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); + const llama_token id = llama_sampling_sample(ctx, NULL, sampling_context, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -386,7 +386,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - llama_sampling_state_reset(sampling_state, client.seq_id); + llama_sampling_context_reset(sampling_context, client.seq_id); client.seq_id = -1; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 08bdaa1899d43..50e344a788fdb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -200,7 +200,7 @@ struct llama_server_context llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; - llama_sampling_state sampling_state; + llama_sampling_context sampling_context; int n_ctx; grammar_parser::parse_state parsed_grammar; @@ -255,7 +255,7 @@ struct llama_server_context if (grammar != nullptr) { llama_grammar_free(grammar); grammar = nullptr; - sampling_state = llama_sampling_state_init(params, NULL); + sampling_context = llama_sampling_context_init(params, NULL); } } @@ -341,7 +341,7 @@ struct llama_server_context grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - sampling_state = llama_sampling_state_init(params, grammar); + sampling_context = llama_sampling_context_init(params, grammar); return true; } @@ -542,7 +542,7 @@ struct llama_server_context std::vector candidates; candidates.reserve(llama_n_vocab(model)); - result.tok = llama_sample_token(ctx, NULL, sampling_state, last_n_tokens, candidates); + result.tok = llama_sampling_sample(ctx, NULL, sampling_context, last_n_tokens, candidates); llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; @@ -1210,7 +1210,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla } } - llama.sampling_state = llama_sampling_state_init(llama.params, llama.grammar); + llama.sampling_context = llama_sampling_context_init(llama.params, llama.grammar); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6a07c1c176f46..d5b19cf56375c 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - llama_sampling_state sampling_state = llama_sampling_state_init(params, grammar_tgt); + llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar_tgt); const auto t_dec_start = ggml_time_us(); @@ -136,7 +136,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - llama_token id = llama_sample_token(ctx_tgt, NULL, sampling_state, last_tokens, candidates, i_dft); + llama_token id = llama_sampling_sample(ctx_tgt, NULL, sampling_context, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -215,9 
+215,9 @@ int main(int argc, char ** argv) { } // Note: Hardcoded to sequence id 0, if this ever supports parallel generation // that will need to change. - auto it = sampling_state.sequence_states.find(0); - GGML_ASSERT(it != sampling_state.sequence_states.end()); - // This is necessary because each sequence id in sequence_states + auto it = sampling_context.sequence_contexts.find(0); + GGML_ASSERT(it != sampling_context.sequence_contexts.end()); + // This is necessary because each sequence id in sequence_contexts // uses a copy of the original grammar. grammar_dft = llama_grammar_copy(it->second.grammar); From fffa4c00992b12bf91e8dd1b1e1eb117f2538e78 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 11 Oct 2023 11:48:23 -0600 Subject: [PATCH 6/6] Use more consistant naming convention for sampling contexts --- common/sampling.cpp | 38 ++++++++++++++-------------- common/sampling.h | 10 ++++---- examples/infill/infill.cpp | 4 +-- examples/main/main.cpp | 4 +-- examples/parallel/parallel.cpp | 6 ++--- examples/server/server.cpp | 10 ++++---- examples/speculative/speculative.cpp | 8 +++--- 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 78cb038853c96..8ce4194593ca7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -21,36 +21,36 @@ llama_sampling_context llama_sampling_context_init( // Note: Creates the context if it doesn't exist, so this always return something. llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq) { - const auto it = sampling_ctx.sequence_contexts.find(seq); - if (it != sampling_ctx.sequence_contexts.end()) { + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it != ctx_sampling.sequence_contexts.end()) { return it->second; } llama_sampler_sequence_context new_ctx = { - 2.0f * sampling_ctx.params.mirostat_tau, - sampling_ctx.grammar != NULL ? llama_grammar_copy(sampling_ctx.grammar) : NULL, + 2.0f * ctx_sampling.params.mirostat_tau, + ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, }; - return sampling_ctx.sequence_contexts.insert({seq, new_ctx}).first->second; + return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; } bool llama_sampling_context_reset( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq) { - const auto it = sampling_ctx.sequence_contexts.find(seq); - if (it == sampling_ctx.sequence_contexts.end()) return false; + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it == ctx_sampling.sequence_contexts.end()) return false; if (it->second.grammar != NULL) { llama_grammar_free(it->second.grammar); it->second.grammar = NULL; } - sampling_ctx.sequence_contexts.erase(it); + ctx_sampling.sequence_contexts.erase(it); return true; } llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_context & sampling_ctx, + struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, const int idx, @@ -58,7 +58,7 @@ llama_token llama_sampling_sample( const int n_ctx = llama_n_ctx(ctx); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const llama_sampling_params & params = sampling_ctx.params; + const llama_sampling_params & params = ctx_sampling.params; const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? 
n_vocab : params.top_k; const float top_p = params.top_p; @@ -115,10 +115,10 @@ llama_token llama_sampling_sample( } } - llama_sampler_sequence_context & seq_ctx = llama_sampling_get_sequence_context(sampling_ctx, seq); + llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); - if (seq_ctx.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, seq_ctx.grammar); + if (ctx_seq.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); } if (temp <= 0) { @@ -128,10 +128,10 @@ llama_token llama_sampling_sample( if (mirostat == 1) { const int mirostat_m = 100; llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &seq_ctx.mirostat_mu); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); } else if (mirostat == 2) { llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &seq_ctx.mirostat_mu); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); } else { // Temperature sampling size_t min_keep = std::max(1, params.n_probs); @@ -158,8 +158,8 @@ llama_token llama_sampling_sample( } } - if (seq_ctx.grammar != NULL) { - llama_grammar_accept_token(ctx, seq_ctx.grammar, id); + if (ctx_seq.grammar != NULL) { + llama_grammar_accept_token(ctx, ctx_seq.grammar, id); } return id; diff --git a/common/sampling.h b/common/sampling.h index 96891640f76bc..0aab5d03c2f61 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -67,16 +67,16 @@ llama_sampling_context llama_sampling_context_init( // Fetches the sampler context for the specified sequence id (defaults to 0). // If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the sampling_ctx argument. +// default values based on the parameters in the ctx_sampling argument. llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq = 0); // Reset the sampler context for the supplied sequence id (defaults to 0). // This is necessary to reuse a sequence id or free memory used by sequences // that are no longer required. 
bool llama_sampling_context_reset( - llama_sampling_context & sampling_ctx, + llama_sampling_context & ctx_sampling, const llama_seq_id seq = 0); // this is a common sampling function used across the examples for convenience @@ -86,7 +86,7 @@ bool llama_sampling_context_reset( // // required: // - ctx: context to use for sampling -// - sampling_ctx: sampling-specific context +// - ctx_sampling: sampling-specific context // // optional: // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL @@ -101,7 +101,7 @@ bool llama_sampling_context_reset( llama_token llama_sampling_sample( struct llama_context * ctx, struct llama_context * ctx_guidance, - struct llama_sampling_context & sampling_ctx, + struct llama_sampling_context & ctx_sampling, const std::vector & last_tokens, std::vector & candidates, const int idx = 0, diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index cd9e6b14202f7..525a3beeef4c4 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -423,7 +423,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -542,7 +542,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index eaa99459adf4a..b39a67d979c88 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -470,7 +470,7 @@ int main(int argc, char ** argv) { const int n_vocab = llama_n_vocab(model); - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar); std::vector candidates; candidates.reserve(n_vocab); @@ -627,7 +627,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx, ctx_guidance, sampling_context, last_tokens, candidates); + const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates); last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 50025a71c9614..cdb198a7c10c1 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); - llama_sampling_context sampling_context = llama_sampling_context_init(params, NULL); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL); // load the prompts from an external file if there are any if (params.prompt.empty()) { @@ -341,7 +341,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(ctx, NULL, sampling_context, client.tokens_prev, candidates, client.i_batch - i, 
client.seq_id); + const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -386,7 +386,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - llama_sampling_context_reset(sampling_context, client.seq_id); + llama_sampling_context_reset(ctx_sampling, client.seq_id); client.seq_id = -1; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50e344a788fdb..e906a17bf2e42 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -200,7 +200,7 @@ struct llama_server_context llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; - llama_sampling_context sampling_context; + llama_sampling_context ctx_sampling; int n_ctx; grammar_parser::parse_state parsed_grammar; @@ -255,7 +255,7 @@ struct llama_server_context if (grammar != nullptr) { llama_grammar_free(grammar); grammar = nullptr; - sampling_context = llama_sampling_context_init(params, NULL); + ctx_sampling = llama_sampling_context_init(params, NULL); } } @@ -341,7 +341,7 @@ struct llama_server_context grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - sampling_context = llama_sampling_context_init(params, grammar); + ctx_sampling = llama_sampling_context_init(params, grammar); return true; } @@ -542,7 +542,7 @@ struct llama_server_context std::vector candidates; candidates.reserve(llama_n_vocab(model)); - result.tok = llama_sampling_sample(ctx, NULL, sampling_context, last_n_tokens, candidates); + result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates); llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; @@ -1210,7 +1210,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla } } - llama.sampling_context = llama_sampling_context_init(llama.params, llama.grammar); + llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index d5b19cf56375c..018dbf9a205b9 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -125,7 +125,7 @@ int main(int argc, char ** argv) { grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } - llama_sampling_context sampling_context = llama_sampling_context_init(params, grammar_tgt); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt); const auto t_dec_start = ggml_time_us(); @@ -136,7 +136,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - llama_token id = llama_sampling_sample(ctx_tgt, NULL, sampling_context, last_tokens, candidates, i_dft); + llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -215,8 +215,8 @@ int main(int argc, char ** argv) { } // Note: Hardcoded to sequence id 0, if this ever supports parallel generation // that will need to change. 
-            auto it = sampling_context.sequence_contexts.find(0);
-            GGML_ASSERT(it != sampling_context.sequence_contexts.end());
+            auto it = ctx_sampling.sequence_contexts.find(0);
+            GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
             // This is necessary because each sequence id in sequence_contexts
             // uses a copy of the original grammar.
             grammar_dft = llama_grammar_copy(it->second.grammar);
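
The series above replaces the per-call static mirostat state with an explicit per-sequence sampling context. For orientation, here is a minimal sketch of the call pattern the refactor expects from a caller juggling several sequences, loosely modeled on examples/parallel/parallel.cpp. It is not taken from the patches themselves: it only reuses the names introduced in the final patch (llama_sampling_context_init, llama_sampling_sample, llama_sampling_context_reset), while the logits index, the llama_token / llama_token_data element types of the two vectors, and the surrounding setup are illustrative assumptions.

// Illustrative sketch (assumptions noted above): drive the refactored sampling
// API across two sequences. Assumes `ctx` has already decoded a batch containing
// both sequences, so logits for slots 0 and 1 are available.
#include "common.h"
#include "sampling.h"

#include <vector>

static void sample_two_sequences(llama_context * ctx, llama_model * model, gpt_params & params) {
    // one context for all sequences; per-sequence state (mirostat mu, grammar copy)
    // is created lazily the first time a sequence id is seen
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, /*grammar=*/NULL);

    // repetition-penalty window, as in the examples (all zeros = no history yet)
    std::vector<llama_token> last_tokens(params.sampling_params.repeat_last_n, 0);

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(model));

    for (llama_seq_id seq = 0; seq < 2; ++seq) {
        // idx selects which set of logits to sample from (normally the batch index
        // the sequence was decoded at, e.g. client.i_batch - i in parallel.cpp)
        const llama_token id = llama_sampling_sample(
            ctx, /*ctx_guidance=*/NULL, ctx_sampling,
            last_tokens, candidates, /*idx=*/(int) seq, seq);
        (void) id; // a real caller would append id to the sequence and keep decoding

        // once the sequence is finished, drop its sampler state so the id can be reused
        llama_sampling_context_reset(ctx_sampling, seq);
    }
}

The last call is the important part: because sequence state (the mirostat mu and any grammar copy) now lives inside the context keyed by sequence id, the caller must reset a sequence's entry before reusing its id or when it no longer needs it, exactly as parallel.cpp does when a client finishes.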