diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 128f14290e32..9791eafca24d 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -25,7 +25,6 @@
 #include
 #include
-#include
 
 #include
 
 namespace {
@@ -57,16 +56,13 @@ struct LLamaPrivate {
     bool empty = true;
 };
 
-static int llama_sample_top_p_top_k(
-        llama_context *ctx,
-        const llama_token *last_n_tokens_data,
-        int last_n_tokens_size,
-        int top_k,
-        float top_p,
-        float temp,
-        float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
-    auto n_vocab = llama_n_vocab(ctx);
+llama_token LLamaModel::sample_top_p_top_k(PromptContext &promptCtx)
+{
+    const auto last_n_tokens_size = std::min(static_cast<size_t>(promptCtx.repeat_last_n), promptCtx.tokens.size());
+    const auto *last_n_tokens_data = &promptCtx.tokens.data()[promptCtx.tokens.size() - last_n_tokens_size];
+
+    auto logits = llama_get_logits(d_ptr->ctx);
+    auto n_vocab = llama_n_vocab(d_ptr->ctx);
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -75,14 +71,14 @@
     }
     llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
     // Sample repeat penalty
-    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
+    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, promptCtx.repeat_penalty);
     // Temperature sampling
-    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-    llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-    llama_sample_temperature(ctx, &candidates_p, temp);
-    return llama_sample_token(ctx, &candidates_p);
+    llama_sample_top_k(d_ptr->ctx, &candidates_p, promptCtx.top_k, 1);
+    llama_sample_tail_free(d_ptr->ctx, &candidates_p, promptCtx.tfs_z, 1);
+    llama_sample_typical(d_ptr->ctx, &candidates_p, promptCtx.typical_p, 1);
+    llama_sample_top_p(d_ptr->ctx, &candidates_p, promptCtx.top_p, 1);
+    llama_sample_temperature(d_ptr->ctx, &candidates_p, promptCtx.temp);
+    return llama_sample_token(d_ptr->ctx, &candidates_p);
 }
 
 LLamaModel::LLamaModel()
@@ -233,11 +229,7 @@ void LLamaModel::prompt(const std::string &prompt,
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
-        const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
-        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
-            promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
-            n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-            promptCtx.repeat_penalty);
+        llama_token id = sample_top_p_top_k(promptCtx);
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index a4e5d99cd216..024d2b054047 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -7,6 +7,7 @@
 #include <string>
 #include <functional>
 #include <vector>
+#include <llama.h>
 #include "llmodel.h"
 
 struct LLamaPrivate;
@@ -33,6 +34,7 @@
         std::function<bool(bool)> recalculate) override;
 
 private:
+    llama_token sample_top_p_top_k(PromptContext &promptCtx);
     LLamaPrivate *d_ptr;
 };
 
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index a9f8d16c59cd..5e4e388f4aaa 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -27,6 +27,8 @@ class LLModel {
         int32_t top_k = 40;
         float top_p = 0.9f;
         float temp = 0.9f;
+        float tfs_z = 1.0f;     // tail free sampling, parameter z (1.0 = disabled)
+        float typical_p = 1.0f; // locally typical sampling, parameter p (1.0 = disabled)
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 439b4bba6f1d..8e0ee4d406c8 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -123,6 +123,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.top_k = ctx->top_k;
     wrapper->promptContext.top_p = ctx->top_p;
     wrapper->promptContext.temp = ctx->temp;
+    wrapper->promptContext.tfs_z = ctx->tfs_z;
+    wrapper->promptContext.typical_p = ctx->typical_p;
     wrapper->promptContext.n_batch = ctx->n_batch;
     wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
     wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
@@ -145,6 +147,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     ctx->top_k = wrapper->promptContext.top_k;
     ctx->top_p = wrapper->promptContext.top_p;
     ctx->temp = wrapper->promptContext.temp;
+    ctx->tfs_z = wrapper->promptContext.tfs_z;
+    ctx->typical_p = wrapper->promptContext.typical_p;
     ctx->n_batch = wrapper->promptContext.n_batch;
     ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
     ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index 58d2669c1580..7ddc2083e901 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -42,6 +42,8 @@ struct llmodel_prompt_context {
    int32_t top_k;          // top k logits to sample from
    float top_p;            // nucleus sampling probability threshold
    float temp;             // temperature to adjust model's output distribution
+   float tfs_z;            // tail free sampling, parameter z (1.0 = disabled)
+   float typical_p;        // locally typical sampling, parameter p (1.0 = disabled)
    int32_t n_batch;        // number of predictions to generate in parallel
    float repeat_penalty;   // penalty factor for repeated tokens
    int32_t repeat_last_n;  // last n tokens to penalize
diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp
index 867117ef9d12..554c69fea987 100644
--- a/gpt4all-bindings/golang/binding.cpp
+++ b/gpt4all-bindings/golang/binding.cpp
@@ -93,6 +93,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la
         .top_k = 10,
         .top_p = 0.9,
         .temp = 1.0,
+        .tfs_z = 1.0f,
+        .typical_p = 1.0f,
         .n_batch = 1,
         .repeat_penalty = 1.2,
         .repeat_last_n = 10,
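Note for C-API consumers: the two new `float` fields land in the middle of `llmodel_prompt_context`, so bindings built against the old header layout need to be rebuilt, and a caller that zero-initializes the struct gets `tfs_z == 0.0f` and `typical_p == 0.0f`, which is aggressive filtering rather than "disabled". Below is a minimal sketch (not part of the patch; the helper name is hypothetical and the other values are illustrative defaults mirroring `llmodel.h`) of how a caller would set the new fields:

```c
#include "llmodel_c.h"

/* Minimal sketch, not from the patch: building the extended
 * llmodel_prompt_context. Only tfs_z and typical_p are new; the other
 * values are illustrative defaults mirroring llmodel.h. Unlisted fields
 * (logits, tokens, n_past, ...) are zero-initialized. */
static struct llmodel_prompt_context make_default_prompt_context(void)
{
    struct llmodel_prompt_context ctx = {
        .top_k = 40,
        .top_p = 0.9f,
        .temp = 0.9f,
        .tfs_z = 1.0f,      /* tail free sampling: 1.0 = disabled */
        .typical_p = 1.0f,  /* locally typical sampling: 1.0 = disabled */
        .n_batch = 9,
        .repeat_penalty = 1.10f,
        .repeat_last_n = 64,
    };
    return ctx;
}
```

In llama.cpp's samplers of this era, `llama_sample_tail_free` and `llama_sample_typical` return early when their parameter is >= 1.0, so pinning both fields to 1.0f keeps the chain equivalent to the previous top-k / top-p / temperature pipeline; that is why the golang binding above sets exactly those values.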