Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Min P sampler implementation [alternative to Top P/Top K] #3841

Merged
merged 25 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
59d1232
cuda : prints wip
ggerganov Oct 25, 2023
52af782
cuda : new cublas gemm branch for multi-batch quantized src0
ggerganov Oct 25, 2023
16b60dd
cuda : add F32 sgemm branch
ggerganov Oct 25, 2023
a3c2843
cuda : fine-tune >= VOLTA params + use MMQ only for small batches
ggerganov Oct 25, 2023
4c6744b
cuda : remove duplicated cuBLAS GEMM code
ggerganov Oct 25, 2023
a4e15a3
cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros
ggerganov Oct 25, 2023
49af767
build : add compile option to force use of MMQ kernels
ggerganov Oct 27, 2023
a9e2b74
Super hacky starting implementation of Min P
kalomaze Oct 28, 2023
a235a0d
Transform Min P into a proper CLI option
kalomaze Oct 29, 2023
838d58d
Min P disabled if set to 1.0 or 0, otherwise Top P
kalomaze Oct 29, 2023
69ef4ca
Debugging print statements removed
kalomaze Oct 29, 2023
833637b
erring on the side of caution; disable by default
kalomaze Oct 29, 2023
62fc771
Remove accidentally kept prints + min_keep support
kalomaze Oct 29, 2023
49b68e8
Standardize 0.0 disabling min_p upon feedback
kalomaze Oct 29, 2023
6f7cdec
Simplified counter by checking candidates size
kalomaze Oct 29, 2023
cb23358
minor whitespace fix
kalomaze Oct 29, 2023
fcbbfc1
Even formatting + exclusively 0.0f to disable now
kalomaze Oct 29, 2023
69e638e
cleanup
cebtenzzre Oct 29, 2023
3ddfd67
permit simultaneous use of top_p and min_p
cebtenzzre Oct 29, 2023
18c0aa7
Merge remote-tracking branch 'original/cuda-quantum-batch' into min-p…
kalomaze Oct 29, 2023
87adfad
Merge branch 'min-p-sampling' of https://github.com/kalomaze/koboldcp…
kalomaze Oct 29, 2023
9248325
Update README & set 0.05 default
kalomaze Oct 31, 2023
512cac6
added a bit more context to the README
kalomaze Oct 31, 2023
974640a
Update README for consistency
kalomaze Oct 31, 2023
3b58af2
forgot one small thing!
kalomaze Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions common/common.cpp
kalomaze marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.top_p = std::stof(argv[i]);
} else if (arg == "--min-p") { // Adding min_p argument
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.min_p = std::stof(argv[i]); // Parsing and setting the min_p value from command line
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
Expand Down Expand Up @@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.05\n", sparams.min_p);
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}
6 changes: 4 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {

snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);

return std::string(result);
Expand All @@ -110,6 +110,7 @@ llama_token llama_sampling_sample(
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
Expand Down Expand Up @@ -190,6 +191,7 @@ llama_token llama_sampling_sample(
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice that min_p comes after all samplers, meaning that sampling would be among top_k->tfs->typical->top_p results, or, as we have enabled almost every sampler, among top_k->top_p results. So, it is among Top K=40, among them Top P 95%, and only then Min P

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That order seems logical. Min-P is supposed to be a Top-P replacement, but if you were to use it together with other samplers it would fulfill the same role as Top-P and should have the same order. And the default being 0 makes sense since Top-P already has a default value, you don't want to combine both actively at the same time.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, it probably shouldn't just be enabled by default. I don't think it was like that when I looked at it, but maybe I missed it.

why we enable by default almost all samplers? If someone want to use just one sampler, the person have to disable all samplers and penalties

Maybe add a --disable-all-samplers option or something? I think the default settings are aimed at generally producing decent results, probably not the case with all samplers disabled.

Copy link
Contributor

@Mihaiii Mihaiii Nov 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't want to combine both actively at the same time.

Why not? To my understanding, it wouldn't be misleading to apply a Min-P on the top logits that combined make a certain procentage since Min-P takes into consideration the procentage of the first (highest probability) logit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it does seem like the default is currently 0.05, which probably should be changed to 0.

llama_sample_temp (ctx_main, &cur_p, temp);

id = llama_sample_token(ctx_main, &cur_p);
Expand Down
1 change: 1 addition & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ typedef struct llama_sampling_params {
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
Expand Down
57 changes: 56 additions & 1 deletion llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7332,7 +7332,7 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
if (p >= 1.0f) {
return;
}

llama_sample_softmax(ctx, candidates);

const int64_t t_start_sample_us = ggml_time_us();
Expand Down Expand Up @@ -7360,6 +7360,61 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
kalomaze marked this conversation as resolved.
Show resolved Hide resolved
float base_min_p = p; // This will hold the base minimum probability value
float multiplied_min_p; // This will hold the adjusted minimum probability threshold

printf("\nUSING MIN P SAMPLING MODE\n\n");

// Ensure the probabilities are calculated.
llama_sample_softmax(ctx, candidates);

// Print the top tokens before filtering
printf("Top tokens before filtering:\n");
for (size_t i = 0; i < candidates->size && i < 10; ++i) {
printf("Token %zu: %.6f%%\n", i + 1, candidates->data[i].p * 100); // Multiplying by 100 to convert to percentage
}

// Calculate the multiplication factor based on the highest scoring token.
float multiplication_factor = candidates->data[0].p; // Assuming the probabilities are sorted
printf("Highest scoring token probability (multiplication factor): %f\n", multiplication_factor);

// Calculate the dynamic threshold.
multiplied_min_p = base_min_p * multiplication_factor;
printf("Base min_p value: %f\n", base_min_p);
printf("Calculated multiplied_min_p (threshold) value: %f\n", multiplied_min_p);

// Store the tokens that meet the threshold in a new list.
std::vector<llama_token_data> filtered_candidates;
filtered_candidates.reserve(candidates->size); // Reserve to avoid multiple reallocations

// Variable to count how many tokens meet the condition
int count_qualifying_tokens = 0;

for (size_t i = 0; i < candidates->size; ++i) {
// If a token's probability is above the threshold, we keep it.
if (candidates->data[i].p >= multiplied_min_p) {
filtered_candidates.push_back(candidates->data[i]);
++count_qualifying_tokens; // Increase count
}
}

// Debug information about how many tokens were retained
printf("Number of tokens that met the multiplied_min_p condition: %d\n", count_qualifying_tokens);

// Print the top tokens after filtering
printf("Tokens after filtering:\n\n");
for (size_t i = 0; i < filtered_candidates.size() && i < 10; ++i) { // Adjust 10 to however many top tokens you want to display
printf("Token %zu: %.6f%%\n", i + 1, filtered_candidates[i].p * 100); // Multiplying by 100 to convert to percentage
}

// Now we replace the original candidates with the filtered list.
std::copy(filtered_candidates.begin(), filtered_candidates.end(), candidates->data);
candidates->size = filtered_candidates.size();

return;
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
if (z >= 1.0f || candidates->size <= 2) {
return;
Expand Down
7 changes: 7 additions & 0 deletions llama.h
cebtenzzre marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,13 @@ extern "C" {
float p,
size_t min_keep);

/// @details Minimum P sampling by Kalomaze
LLAMA_API void llama_sample_min_p(
struct llama_context * ctx,
llama_token_data_array * candidates,
float p,
size_t min_keep);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
Expand Down