Add min-p sampling
brittlewis12 committed Nov 17, 2023
1 parent 18dadad commit d03f651
Showing 5 changed files with 43 additions and 0 deletions.
7 changes: 7 additions & 0 deletions Sources/llmfarm_core_cpp/ggml/common.cpp
@@ -50,6 +50,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--top_p") {
params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
} else if (arg == "--min_p") {
if (++i >= argc) {
invalid_param = true; // flag given without a value
break;
}
params.min_p = std::stof(argv[i]); // next token holds the value
} else if (arg == "--temp") {
params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
params.temp = std::max(params.temp, 0.0f);
@@ -110,6 +116,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --min_p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)params.min_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
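For illustration, here is a minimal self-contained sketch of how a "--min_p 0.05" argument would be consumed by a parser in the spirit of the gpt_params_parse branch above; the simulated argv is hypothetical, not part of the commit:

    #include <cstdio>
    #include <string>

    int main() {
        // Hypothetical command line: ./main --min_p 0.05
        const char * argv_sim[] = { "main", "--min_p", "0.05" };
        const int argc_sim = 3;
        float min_p = 0.0f; // 0.0 = disabled, matching the usage text above

        for (int i = 1; i < argc_sim; ++i) {
            std::string arg = argv_sim[i];
            if (arg == "--min_p" && ++i < argc_sim) {
                min_p = std::stof(argv_sim[i]); // next token holds the value
            }
        }
        std::printf("min_p = %.2f\n", min_p); // prints: min_p = 0.05
        return 0;
    }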
2 changes: 2 additions & 0 deletions Sources/llmfarm_core_cpp/ggml/sampling.cpp
@@ -62,6 +62,7 @@ llama_token llama_sampling_sample(
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
@@ -143,6 +144,7 @@ llama_token llama_sampling_sample(
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_min_p (ctx, &cur_p, min_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);

{
1 change: 1 addition & 0 deletions Sources/llmfarm_core_cpp/ggml/sampling.h
@@ -10,6 +10,7 @@
typedef struct llama_sampling_params {
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
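A configuration sketch under the defaults above; neutralizing the other truncation samplers isolates the min-p behavior (the helper name and the include path are invented for illustration):

    #include "sampling.h"

    // Hypothetical helper: make min-p the only active truncation filter.
    llama_sampling_params make_min_p_only_params() {
        llama_sampling_params params;
        params.top_k     = 0;     // <= 0 means use the full vocab (top-k off)
        params.top_p     = 1.0f;  // 1.0 = disabled
        params.tfs_z     = 1.0f;  // 1.0 = disabled
        params.typical_p = 1.0f;  // 1.0 = disabled
        params.min_p     = 0.05f; // drop tokens below 5% of the top token's probability
        return params;
    }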
26 changes: 26 additions & 0 deletions Sources/llmfarm_core_cpp/llama/llama.cpp
@@ -7617,6 +7617,32 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
if (p <= 0.0f || !candidates->size) {
return;
}

llama_sample_softmax(ctx, candidates); // sorts candidates by probability (descending) and normalizes them

const int64_t t_start_sample_us = ggml_time_us();

float scale = candidates->data[0].p; // scale by max prob
size_t i = 1; // first token always matches

for (; i < candidates->size; ++i) {
if (candidates->data[i].p < p * scale && i >= min_keep) {
break; // prob too small
}
}

// Resize the output vector to keep only the matching tokens
candidates->size = i;

if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
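To make the cutoff concrete, here is a standalone sketch of the same rule applied to a toy, already-softmaxed distribution (the probabilities are invented for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        // Sorted descending, as llama_sample_softmax leaves the candidates.
        std::vector<float> probs = { 0.50f, 0.20f, 0.10f, 0.04f, 0.01f };
        const float  p        = 0.10f; // min-p threshold
        const size_t min_keep = 1;

        const float scale = probs[0]; // max probability: 0.50
        size_t i = 1;                 // the top token always survives
        for (; i < probs.size(); ++i) {
            if (probs[i] < p * scale && i >= min_keep) {
                break; // 0.04 < 0.10 * 0.50 = 0.05, so the tail is cut here
            }
        }
        std::printf("kept %zu of %zu tokens\n", i, probs.size()); // kept 3 of 5
        return 0;
    }

Unlike top-p, which keeps a fixed cumulative mass, the cutoff here scales with the model's confidence: a sharply peaked distribution prunes aggressively, a flat one keeps more candidates.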
7 changes: 7 additions & 0 deletions Sources/llmfarm_core_cpp/spm-headers/llama.h
@@ -599,6 +599,13 @@ extern "C" {
float p,
size_t min_keep);

/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
LLAMA_API void llama_sample_min_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
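A hedged call-site sketch for the new API, assuming ctx, logits, and n_vocab are already in scope; the surrounding variable names, and the use of llama_sample_temp and llama_sample_token to finish the step, are illustrative rather than taken from this commit:

    // Hypothetical sampling step: build candidates, truncate with min-p, sample.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
        candidates.push_back({ token_id, logits[token_id], 0.0f });
    }
    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    llama_sample_min_p(ctx, &cur_p, 0.05f, 1); // drop tokens under 5% of the max prob
    llama_sample_temp (ctx, &cur_p, 0.8f);     // then apply temperature
    llama_token id = llama_sample_token(ctx, &cur_p);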
