From a4ab8e5d834f622c6f652c790caebcf18aced431 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Sat, 21 Oct 2023 15:59:54 -0700
Subject: [PATCH 1/9] added `llama_model_token_*` variants to all the
 `llama_token_*` functions.

---
 llama.cpp | 28 ++++++++++++++++++++++++++++
 llama.h   | 10 ++++++++++
 2 files changed, 38 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 3653493355234..473ec2a3cc218 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9669,29 +9669,57 @@ llama_token llama_token_bos(const struct llama_context * ctx) {
     return ctx->model.vocab.special_bos_id;
 }
 
+llama_token llama_model_token_bos(const struct llama_model * model) {
+    return model->vocab.special_bos_id;
+}
+
 llama_token llama_token_eos(const struct llama_context * ctx) {
     return ctx->model.vocab.special_eos_id;
 }
 
+llama_token llama_model_token_eos(const struct llama_model * model) {
+    return model->vocab.special_eos_id;
+}
+
 llama_token llama_token_nl(const struct llama_context * ctx) {
     return ctx->model.vocab.linefeed_id;
 }
+
+llama_token llama_model_token_nl(const struct llama_model * model) {
+    return model->vocab.linefeed_id;
+}
 llama_token llama_token_prefix(const struct llama_context * ctx) {
     return ctx->model.vocab.special_prefix_id;
 }
 
+llama_token llama_model_token_prefix(const struct llama_model * model) {
+    return model->vocab.special_prefix_id;
+}
+
 llama_token llama_token_middle(const struct llama_context * ctx) {
     return ctx->model.vocab.special_middle_id;
 }
 
+llama_token llama_model_token_middle(const struct llama_model * model) {
+    return model->vocab.special_middle_id;
+}
+
 llama_token llama_token_suffix(const struct llama_context * ctx) {
     return ctx->model.vocab.special_suffix_id;
 }
 
+llama_token llama_model_token_suffix(const struct llama_model * model) {
+    return model->vocab.special_suffix_id;
+}
+
 llama_token llama_token_eot(const struct llama_context * ctx) {
     return ctx->model.vocab.special_eot_id;
 }
 
+llama_token llama_model_token_eot(const struct llama_model * model) {
+    return model->vocab.special_eot_id;
+}
+
 int llama_tokenize(
     const struct llama_model * model,
                   const char * text,
diff --git a/llama.h b/llama.h
index 306f5b383cb11..f30c01f990a90 100644
--- a/llama.h
+++ b/llama.h
@@ -504,12 +504,22 @@ extern "C" {
     LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx);  // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx);  // end-of-sentence
     LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx);  // next-line
+
+    LLAMA_API llama_token llama_model_token_bos(const struct llama_model *model);
+    LLAMA_API llama_token llama_model_token_eos(const struct llama_model *model);
+    LLAMA_API llama_token llama_model_token_nl(const struct llama_model *model);
+
     // codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix
     LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle
     LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix
     LLAMA_API llama_token llama_token_eot   (const struct llama_context * ctx); // End of infill middle
 
+    llama_token llama_model_token_prefix(const struct llama_model * model);
+    llama_token llama_model_token_middle(const struct llama_model * model);
+    llama_token llama_model_token_suffix(const struct llama_model * model);
+    llama_token llama_model_token_eot   (const struct llama_model * model);
+
     //
     // Tokenization
     //

From 7b127a734dab657701cfcc00c43cca3083048ea0 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Sat, 21 Oct 2023 16:07:26 -0700
Subject: [PATCH 2/9] added `LLAMA_API`

---
 llama.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama.h b/llama.h
index f30c01f990a90..dc633580f15d1 100644
--- a/llama.h
+++ b/llama.h
@@ -515,10 +515,10 @@ extern "C" {
     LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix
     LLAMA_API llama_token llama_token_eot   (const struct llama_context * ctx); // End of infill middle
 
-    llama_token llama_model_token_prefix(const struct llama_model * model);
-    llama_token llama_model_token_middle(const struct llama_model * model);
-    llama_token llama_model_token_suffix(const struct llama_model * model);
-    llama_token llama_model_token_eot   (const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_prefix(const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_middle(const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_suffix(const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_eot   (const struct llama_model * model);
 
     //
     // Tokenization

From 353f4ef717d3663038721e3aa22fb0a69c5b9113 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com>
Date: Mon, 23 Oct 2023 09:08:46 -0700
Subject: [PATCH 3/9] formatting

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
 llama.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama.h b/llama.h
index dc633580f15d1..b6c99665c1b53 100644
--- a/llama.h
+++ b/llama.h
@@ -505,9 +505,9 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx);  // end-of-sentence
     LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx);  // next-line
 
-    LLAMA_API llama_token llama_model_token_bos(const struct llama_model *model);
-    LLAMA_API llama_token llama_model_token_eos(const struct llama_model *model);
-    LLAMA_API llama_token llama_model_token_nl(const struct llama_model *model);
+    LLAMA_API llama_token llama_model_token_bos(const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_eos(const struct llama_model * model);
+    LLAMA_API llama_token llama_model_token_nl (const struct llama_model * model);
 
     // codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix

From 22d5eb41bb4f0908061211f023cf5e0072bc3df7 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:15:48 -0700
Subject: [PATCH 4/9] removed old `llama_token` functions

---
 common/common.cpp   |  8 ++++----
 common/sampling.cpp |  4 ++--
 llama.cpp           | 47 ++++++++++-----------------------------------
 llama.h             | 23 +++++++---------------
 4 files changed, 23 insertions(+), 59 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 2ef902bd504c4..a9bb4ebb9f74e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -879,13 +879,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
     }
 
     {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
@@ -940,7 +940,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
 }
 
 std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
-    const llama_token bos_id = llama_token_bos(ctx);
+    const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
 
     std::string piece;
     std::string result;
@@ -1185,7 +1185,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
 
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
+    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
     const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
     fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
 
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 6f0af3c4a2afd..5258d4e826369 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -147,7 +147,7 @@ llama_token llama_sampling_sample(
 
     // apply penalties
     if (!prev.empty()) {
-        const float nl_logit = logits[llama_token_nl(ctx_main)];
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
 
         llama_sample_repetition_penalties(ctx_main, &cur_p,
                 prev.data() + prev.size() - penalty_last_n,
@@ -155,7 +155,7 @@ llama_token llama_sampling_sample(
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
                     cur_p.data[idx].logit = nl_logit;
                     break;
                 }
diff --git a/llama.cpp b/llama.cpp
index 473ec2a3cc218..826addd08e568 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7473,7 +7473,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    const llama_token eos = llama_token_eos(ctx);
+    const llama_token eos = llama_token_eos(&ctx->model);
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate>                              candidates_grammar;
@@ -7683,7 +7683,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
-    if (token == llama_token_eos(ctx)) {
+    if (token == llama_token_eos(&ctx->model)) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
                 return;
@@ -8892,7 +8892,7 @@ struct llama_context * llama_new_context_with_model(
             // build worst-case graph
             int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
             int n_past = cparams.n_ctx - n_tokens;
-            llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
             ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
 
 #ifdef GGML_USE_METAL
@@ -9665,58 +9665,31 @@ llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_to
     return ctx->model.vocab.id_to_token[token].type;
 }
 
-llama_token llama_token_bos(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_bos_id;
-}
-
-llama_token llama_model_token_bos(const struct llama_model * model) {
+llama_token llama_token_bos(const struct llama_model * model) {
     return model->vocab.special_bos_id;
 }
 
-llama_token llama_token_eos(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_eos_id;
-}
-
-llama_token llama_model_token_eos(const struct llama_model * model) {
+llama_token llama_token_eos(const struct llama_model * model) {
     return model->vocab.special_eos_id;
 }
 
-llama_token llama_token_nl(const struct llama_context * ctx) {
-    return ctx->model.vocab.linefeed_id;
-}
-
-llama_token llama_model_token_nl(const struct llama_model * model) {
+llama_token llama_token_nl(const struct llama_model * model) {
     return model->vocab.linefeed_id;
 }
-llama_token llama_token_prefix(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_prefix_id;
-}
 
-llama_token llama_model_token_prefix(const struct llama_model * model) {
+llama_token llama_token_prefix(const struct llama_model * model) {
     return model->vocab.special_prefix_id;
 }
 
-llama_token llama_token_middle(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_middle_id;
-}
-
-llama_token llama_model_token_middle(const struct llama_model * model) {
+llama_token llama_token_middle(const struct llama_model * model) {
     return model->vocab.special_middle_id;
 }
 
-llama_token llama_token_suffix(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_suffix_id;
-}
-
-llama_token llama_model_token_suffix(const struct llama_model * model) {
+llama_token llama_token_suffix(const struct llama_model * model) {
     return model->vocab.special_suffix_id;
 }
 
-llama_token llama_token_eot(const struct llama_context * ctx) {
-    return ctx->model.vocab.special_eot_id;
-}
-
-llama_token llama_model_token_eot(const struct llama_model * model) {
+llama_token llama_token_eot(const struct llama_model * model) {
     return model->vocab.special_eot_id;
 }
 
diff --git a/llama.h b/llama.h
index b6c99665c1b53..671e92a0efe6d 100644
--- a/llama.h
+++ b/llama.h
@@ -501,24 +501,15 @@ extern "C" {
     LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
 
     // Special tokens
-    LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx);  // beginning-of-sentence
-    LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx);  // end-of-sentence
-    LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx);  // next-line
-
-    LLAMA_API llama_token llama_model_token_bos(const struct llama_model * model);
-    LLAMA_API llama_token llama_model_token_eos(const struct llama_model * model);
-    LLAMA_API llama_token llama_model_token_nl (const struct llama_model * model);
+    LLAMA_API llama_token llama_token_bos(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_eos(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_nl (const struct llama_model * model);
 
     // codellama infill tokens
-    LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix
-    LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle
-    LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix
-    LLAMA_API llama_token llama_token_eot   (const struct llama_context * ctx); // End of infill middle
-
-    LLAMA_API llama_token llama_model_token_prefix(const struct llama_model * model);
-    LLAMA_API llama_token llama_model_token_middle(const struct llama_model * model);
-    LLAMA_API llama_token llama_model_token_suffix(const struct llama_model * model);
-    LLAMA_API llama_token llama_model_token_eot   (const struct llama_model * model);
+    LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
+    LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
+    LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
+    LLAMA_API llama_token llama_token_eot   (const struct llama_model * model); // End of infill middle
 
     //
     // Tokenization

From a550b23e3a59f8bbeb29986ebc16814799b2f204 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:17:55 -0700
Subject: [PATCH 5/9] changed 3 more functions to take in model

- `llama_token_get_text`
- `llama_token_get_score`
- `llama_token_get_type`
---
 llama.cpp | 12 ++++++------
 llama.h   |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 826addd08e568..dfcaaeac5f08e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9653,16 +9653,16 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
-const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) {
-    return ctx->model.vocab.id_to_token[token].text.c_str();
+const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
+    return model->vocab.id_to_token[token].text.c_str();
 }
 
-float llama_token_get_score(const struct llama_context * ctx, llama_token token) {
-    return ctx->model.vocab.id_to_token[token].score;
+float llama_token_get_score(const struct llama_model * model, llama_token token) {
+    return model->vocab.id_to_token[token].score;
 }
 
-llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) {
-    return ctx->model.vocab.id_to_token[token].type;
+llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+    return model->vocab.id_to_token[token].type;
 }
 
 llama_token llama_token_bos(const struct llama_model * model) {
diff --git a/llama.h b/llama.h
index 671e92a0efe6d..d01aef8bd3773 100644
--- a/llama.h
+++ b/llama.h
@@ -494,11 +494,11 @@ extern "C" {
     // Vocab
     //
 
-    LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
+    LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
 
-    LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
+    LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
 
-    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
+    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
 
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model);

From 4646c9dadd4855835c2b0e2e72a66cf0af2120b3 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:25:05 -0700
Subject: [PATCH 6/9] added back docs

---
 llama.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama.h b/llama.h
index d01aef8bd3773..2f2fee0e2ff9f 100644
--- a/llama.h
+++ b/llama.h
@@ -501,9 +501,9 @@ extern "C" {
     LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
 
     // Special tokens
-    LLAMA_API llama_token llama_token_bos(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_eos(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_nl (const struct llama_model * model);
+    LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
+    LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
 
     // codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix

From 38cdb8223516603c73ef8aecff74021c2c46191e Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:28:07 -0700
Subject: [PATCH 7/9] fixed main.cpp

---
 examples/main/main.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2621bd539875f..3d9f670b9da7f 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(ctx));
+        embd_inp.push_back(llama_token_bos(model));
         LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of text token in interactive mode
-            if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
+            if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -720,7 +720,7 @@ int main(int argc, char ** argv) {
 
                 if (params.input_prefix_bos) {
                     LOG("adding input prefix BOS token\n");
-                    embd_inp.push_back(llama_token_bos(ctx));
+                    embd_inp.push_back(llama_token_bos(model));
                 }
 
                 std::string buffer;
@@ -804,7 +804,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
             LOG_TEE(" [end of text]\n");
             break;
         }

From 2df380170602d49040a99209a07bbb0bcebaa4f3 Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:29:11 -0700
Subject: [PATCH 8/9] changed token functions to use new model variants

---
 examples/perplexity/perplexity.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 7d0038bd40757..3c2542e8c105e 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -227,7 +227,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
 
             // add BOS token for the first batch of each chunk
             if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(ctx);
+                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
             const auto batch_logits = llama_get_logits(ctx);
@@ -350,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
             // add BOS token for the first batch of each chunk
             if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(ctx);
+                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
             if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {

From fc5bb855458503a5dde4c1f39fa459c2c38950db Mon Sep 17 00:00:00 2001
From: Marcus Dunn <marcus.s.dunn@gmail.com>
Date: Mon, 23 Oct 2023 09:30:00 -0700
Subject: [PATCH 9/9] changed token functions to use new model variants

---
 common/train.cpp                     |  6 +++---
 examples/batched/batched.cpp         |  2 +-
 examples/beam-search/beam-search.cpp |  2 +-
 examples/infill/infill.cpp           | 30 ++++++++++++++--------------
 examples/llama-bench/llama-bench.cpp |  4 ++--
 examples/llava/llava-utils.h         |  2 +-
 examples/parallel/parallel.cpp       |  2 +-
 examples/server/server.cpp           | 14 ++++++-------
 examples/simple/simple.cpp           |  2 +-
 examples/speculative/speculative.cpp |  2 +-
 10 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/common/train.cpp b/common/train.cpp
index 154ca56e5fa87..3cce5da269637 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -236,8 +236,8 @@ int64_t get_example_targets_batch(
     int64_t used_samples = 0;
 
     ggml_set_f32(target_probs, 0.0f);
-    llama_token bos = llama_token_bos(lctx);
-    llama_token eos = llama_token_eos(lctx);
+    llama_token bos = llama_token_bos(llama_get_model(lctx));
+    llama_token eos = llama_token_eos(llama_get_model(lctx));
     // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
         // printf("%s: batch %d\n", __func__, k);
@@ -924,7 +924,7 @@ size_t tokenize_file(
         for (llama_token token=0; token < n_vocab; ++token) {
             max_token_text_size = std::max(
                 max_token_text_size,
-                strlen(llama_token_get_text(lctx, token)));
+                strlen(llama_token_get_text(llama_get_model(lctx), token)));
         }
 
         // upper bound of context byte length.
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 2797329b4fc57..75856a81fe9b1 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
             //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
             // is it an end of stream? -> mark the stream as finished
-            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {
diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp
index f078ab8a87fa5..679b382e19b4e 100644
--- a/examples/beam-search/beam-search.cpp
+++ b/examples/beam-search/beam-search.cpp
@@ -47,7 +47,7 @@ struct beam_search_callback_data {
 // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
 // For example, eob can be flagged due to maximum token length, stop words, etc.
 static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
-    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
 }
 
 // Function matching type llama_beam_search_callback_fn_t.
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 6331335e3c423..9c52b7bbad1db 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -246,14 +246,14 @@ int main(int argc, char ** argv) {
     if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
         inp_sfx.erase(inp_sfx.begin());
     }
-    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+    inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
     if (add_bos) {
-        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
     }
-    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
     embd_inp = inp_pfx;
     embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-    embd_inp.push_back(llama_token_middle(ctx));
+    embd_inp.push_back(llama_token_middle(model));
 
     LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
     LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(ctx));
+        embd_inp.push_back(llama_token_bos(model));
         LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
@@ -577,10 +577,10 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed) {
 
             // deal with eot token in infill mode
-            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
                 if(is_interacting && !params.interactive_first) {
                     // print an eot token
-                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
                 }
                 fflush(stdout);
                 printf("\n");
@@ -627,14 +627,14 @@ int main(int argc, char ** argv) {
                 if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
                     inp_sfx.erase(inp_sfx.begin());
                 }
-                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+                inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
                 if (add_bos) {
-                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
                 }
-                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
+                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
                 embd_inp = inp_pfx;
                 embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-                embd_inp.push_back(llama_token_middle(ctx));
+                embd_inp.push_back(llama_token_middle(model));
                 embd.clear();
                 embd_guidance.clear();
                 n_remain = params.n_predict;
@@ -644,7 +644,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
+            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -661,7 +661,7 @@ int main(int argc, char ** argv) {
 
                 if (params.input_prefix_bos) {
                     LOG("adding input prefix BOS token\n");
-                    embd_inp.push_back(llama_token_bos(ctx));
+                    embd_inp.push_back(llama_token_bos(model));
                 }
 
                 std::string buffer;
@@ -724,7 +724,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
             break;
         }
 
@@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
         }
     }
     if (!params.interactive && n_remain <= 0) {
-        printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
+        printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
         fflush(stdout);
     }
 
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index a04115c962655..20767d555b206 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -933,7 +933,7 @@ struct sql_printer : public printer {
 };
 
 static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-    std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
+    std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
     int n_processed = 0;
 
     llama_set_n_threads(ctx, n_threads, n_threads);
@@ -946,7 +946,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
 }
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-    llama_token token = llama_token_bos(ctx);
+    llama_token token = llama_token_bos(llama_get_model(ctx));
 
     llama_set_n_threads(ctx, n_threads, n_threads);
 
diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h
index 45b2b1ad30226..320c719670b02 100644
--- a/examples/llava/llava-utils.h
+++ b/examples/llava/llava-utils.h
@@ -137,7 +137,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
 inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
     int id = sample_id(ctx_llama, params);
     static std::string ret;
-    if (id == llama_token_eos(ctx_llama)) {
+    if (id == llama_token_eos(llama_get_model(ctx_llama))) {
         ret = "</s>";
     } else {
         ret = llama_token_to_piece(ctx_llama, id);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index eb64adef8a389..9a0b9c183d107 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -347,7 +347,7 @@ int main(int argc, char ** argv) {
                 //        client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
                 if (client.n_decoded > 2 &&
-                        (id == llama_token_eos(ctx) ||
+                        (id == llama_token_eos(model) ||
                          (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
                          client.response.find("User:") != std::string::npos ||
                          client.response.find('\n') != std::string::npos)) {
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c3279dbc9c456..693f9b7735e49 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -726,7 +726,7 @@ struct llama_server_context
 
         if (json_value(data, "ignore_eos", false))
         {
-            slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY;
+            slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
         }
 
         const auto &logit_bias = data.find("logit_bias");
@@ -1056,7 +1056,7 @@ struct llama_server_context
             slot.has_next_token = false;
         }
 
-        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx))
+        if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
         {
             slot.stopped_eos = true;
             slot.has_next_token = false;
@@ -1130,7 +1130,7 @@ struct llama_server_context
 
     json get_formated_generation(llama_client_slot &slot)
     {
-        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
+        const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
                                 eos_bias->second < 0.0f && std::isinf(eos_bias->second);
         return json {
@@ -1555,11 +1555,11 @@ struct llama_server_context
                             suffix_tokens.erase(suffix_tokens.begin());
                         }
 
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
+                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
+                        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
                         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                        prefix_tokens.push_back(llama_token_middle(ctx));
+                        prefix_tokens.push_back(llama_token_middle(model));
                         prompt_tokens = prefix_tokens;
                     }
                     else
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 55385f566aa6f..f376c050994d1 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
             const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
             // is it an end of stream?
-            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
                 LOG_TEE("\n");
 
                 break;
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 894321ce9648c..92ad27e8e423c 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -163,7 +163,7 @@ int main(int argc, char ** argv) {
             printf("%s", token_str.c_str());
             fflush(stdout);
 
-            if (id == llama_token_eos(ctx_tgt)) {
+            if (id == llama_token_eos(model_tgt)) {
                 has_eos = true;
             }