Add tail free and locally typical sampling params
imaami committed May 19, 2023
1 parent b4b7bb6 commit fea4747
Showing 6 changed files with 27 additions and 23 deletions.
38 changes: 15 additions & 23 deletions gpt4all-backend/llamamodel.cpp
@@ -25,7 +25,6 @@
 #include <thread>
 #include <unordered_set>
 
-#include <llama.h>
 #include <ggml.h>
 
 namespace {
@@ -57,16 +56,13 @@ struct LLamaPrivate {
     bool empty = true;
 };
 
-static int llama_sample_top_p_top_k(
-        llama_context *ctx,
-        const llama_token *last_n_tokens_data,
-        int last_n_tokens_size,
-        int top_k,
-        float top_p,
-        float temp,
-        float repeat_penalty) {
-    auto logits = llama_get_logits(ctx);
-    auto n_vocab = llama_n_vocab(ctx);
+llama_token LLamaModel::sample_top_p_top_k(PromptContext &promptCtx)
+{
+    const auto last_n_tokens_size = std::min(static_cast<std::size_t>(promptCtx.repeat_last_n), promptCtx.tokens.size());
+    const auto *last_n_tokens_data = &promptCtx.tokens.data()[promptCtx.tokens.size() - last_n_tokens_size];
+
+    auto logits = llama_get_logits(d_ptr->ctx);
+    auto n_vocab = llama_n_vocab(d_ptr->ctx);
     // Populate initial list of all candidates
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
@@ -75,14 +71,14 @@ static int llama_sample_top_p_top_k(
     }
     llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
     // Sample repeat penalty
-    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, repeat_penalty);
+    llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens_data, last_n_tokens_size, promptCtx.repeat_penalty);
     // Temperature sampling
-    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-    llama_sample_tail_free(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_typical(ctx, &candidates_p, 1.0f, 1);
-    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-    llama_sample_temperature(ctx, &candidates_p, temp);
-    return llama_sample_token(ctx, &candidates_p);
+    llama_sample_top_k(d_ptr->ctx, &candidates_p, promptCtx.top_k, 1);
+    llama_sample_tail_free(d_ptr->ctx, &candidates_p, promptCtx.tfs_z, 1);
+    llama_sample_typical(d_ptr->ctx, &candidates_p, promptCtx.typical_p, 1);
+    llama_sample_top_p(d_ptr->ctx, &candidates_p, promptCtx.top_p, 1);
+    llama_sample_temperature(d_ptr->ctx, &candidates_p, promptCtx.temp);
+    return llama_sample_token(d_ptr->ctx, &candidates_p);
 }
 
 LLamaModel::LLamaModel()
@@ -233,11 +229,7 @@ void LLamaModel::prompt(const std::string &prompt,
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
-        const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
-        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
-            promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
-            n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
-            promptCtx.repeat_penalty);
+        llama_token id = sample_top_p_top_k(promptCtx);
 
         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
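For readers skimming the diff, here is a condensed, annotated sketch of the sampling chain as it stands after this change. It is not the committed code verbatim (the commit implements it as the LLamaModel::sample_top_p_top_k member above); the standalone function, its name, and its explicit parameter list are illustrative only, but the llama.cpp sampler calls and their order are the same.

    #include <cstddef>
    #include <vector>
    #include <llama.h>

    // Sketch of the new chain: each stage narrows the candidate set before a
    // token is drawn. Passing 1.0f for tfs_z or typical_p makes that stage a
    // no-op, which is why the new parameters default to 1.0f.
    static llama_token sample_next(llama_context *ctx,
                                   const llama_token *last_tokens, std::size_t n_last,
                                   int top_k, float top_p, float temp,
                                   float tfs_z, float typical_p, float repeat_penalty)
    {
        const float *logits = llama_get_logits(ctx);
        const int n_vocab = llama_n_vocab(ctx);

        // Start from the full vocabulary with raw logits.
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id)
            candidates.push_back({id, logits[id], 0.0f});
        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

        // Penalize tokens seen in the recent window.
        llama_sample_repetition_penalty(ctx, &candidates_p, last_tokens, n_last, repeat_penalty);
        // Keep only the top_k most likely tokens.
        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
        // Tail free sampling: drop the low-probability tail (z = 1.0f disables).
        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
        // Locally typical sampling: keep tokens whose information content is
        // close to the expected value (p = 1.0f disables).
        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
        // Nucleus (top-p) truncation, then temperature scaling.
        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
        llama_sample_temperature(ctx, &candidates_p, temp);
        // Draw one token from what remains.
        return llama_sample_token(ctx, &candidates_p);
    }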
2 changes: 2 additions & 0 deletions gpt4all-backend/llamamodel_impl.h
@@ -7,6 +7,7 @@
 #include <string>
 #include <functional>
 #include <vector>
+#include <llama.h>
 #include "llmodel.h"
 
 struct LLamaPrivate;
@@ -33,6 +34,7 @@ class LLamaModel : public LLModel {
                 std::function<bool(bool)> recalculate) override;
 
 private:
+    llama_token sample_top_p_top_k(PromptContext &promptCtx);
     LLamaPrivate *d_ptr;
 };
 
2 changes: 2 additions & 0 deletions gpt4all-backend/llmodel.h
@@ -27,6 +27,8 @@ class LLModel {
         int32_t top_k = 40;
         float top_p = 0.9f;
         float temp = 0.9f;
+        float tfs_z = 1.0f;     // tail free sampling, parameter z (1.0 = disabled)
+        float typical_p = 1.0f; // locally typical sampling, parameter p (1.0 = disabled)
         int32_t n_batch = 9;
         float repeat_penalty = 1.10f;
         int32_t repeat_last_n = 64; // last n tokens to penalize
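Because both new fields default to 1.0f, existing C++ callers keep their current top-k/top-p/temperature behaviour without any change. A minimal sketch of how a consumer might opt in, assuming the new fields live in LLModel's nested PromptContext struct (which is what the backend's promptCtx references above suggest); the function name is illustrative, not part of this commit:

    #include "llmodel.h"

    // Illustrative only: lowering either value below 1.0f enables that sampler.
    void configure_sampling(LLModel::PromptContext &promptCtx)
    {
        promptCtx.top_k     = 40;
        promptCtx.top_p     = 0.9f;
        promptCtx.temp      = 0.9f;
        promptCtx.tfs_z     = 0.95f; // tail free sampling enabled (1.0f = disabled)
        promptCtx.typical_p = 0.5f;  // locally typical sampling enabled (1.0f = disabled)
    }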
4 changes: 4 additions & 0 deletions gpt4all-backend/llmodel_c.cpp
@@ -123,6 +123,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.top_k = ctx->top_k;
     wrapper->promptContext.top_p = ctx->top_p;
     wrapper->promptContext.temp = ctx->temp;
+    wrapper->promptContext.tfs_z = ctx->tfs_z;
+    wrapper->promptContext.typical_p = ctx->typical_p;
     wrapper->promptContext.n_batch = ctx->n_batch;
     wrapper->promptContext.repeat_penalty = ctx->repeat_penalty;
     wrapper->promptContext.repeat_last_n = ctx->repeat_last_n;
@@ -145,6 +147,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     ctx->top_k = wrapper->promptContext.top_k;
     ctx->top_p = wrapper->promptContext.top_p;
     ctx->temp = wrapper->promptContext.temp;
+    ctx->tfs_z = wrapper->promptContext.tfs_z;
+    ctx->typical_p = wrapper->promptContext.typical_p;
     ctx->n_batch = wrapper->promptContext.n_batch;
     ctx->repeat_penalty = wrapper->promptContext.repeat_penalty;
     ctx->repeat_last_n = wrapper->promptContext.repeat_last_n;
2 changes: 2 additions & 0 deletions gpt4all-backend/llmodel_c.h
@@ -42,6 +42,8 @@ struct llmodel_prompt_context {
     int32_t top_k;          // top k logits to sample from
     float top_p;            // nucleus sampling probability threshold
     float temp;             // temperature to adjust model's output distribution
+    float tfs_z;            // tail free sampling, parameter z (1.0 = disabled)
+    float typical_p;        // locally typical sampling, parameter p (1.0 = disabled)
     int32_t n_batch;        // number of predictions to generate in parallel
     float repeat_penalty;   // penalty factor for repeated tokens
     int32_t repeat_last_n;  // last n tokens to penalize
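A hedged sketch of how a C-API consumer might fill the extended struct. Only the fields visible in this diff are set; the remaining members of llmodel_prompt_context (not shown here) are left zero-initialized, and the numeric values are placeholders rather than recommended defaults:

    #include "llmodel_c.h"

    // Leaving tfs_z and typical_p at 1.0f keeps the pre-existing behaviour.
    struct llmodel_prompt_context prompt_ctx = {
        .top_k          = 40,
        .top_p          = 0.9f,
        .temp           = 0.9f,
        .tfs_z          = 1.0f,   // tail free sampling disabled
        .typical_p      = 1.0f,   // locally typical sampling disabled
        .n_batch        = 8,
        .repeat_penalty = 1.1f,
        .repeat_last_n  = 64,
    };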
2 changes: 2 additions & 0 deletions gpt4all-bindings/golang/binding.cpp
@@ -93,6 +93,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la
         .top_k = 10,
         .top_p = 0.9,
         .temp = 1.0,
+        .tfs_z = 1.0f,
+        .typical_p = 1.0f,
         .n_batch = 1,
         .repeat_penalty = 1.2,
         .repeat_last_n = 10,
