
Commit 46fdd23

ggerganov and slaren authored and committed
sampling : refactor + optimize penalties sampler (ggml-org#10803)
* sampling : refactor + optimize penalties sampler (ggml-ci)
* common : apply ignore_eos as logit bias (ggml-ci)
* batched : remove penalties sampler
* params : allow penalty_last_n == -1 to be equal to context size (ggml-ci)
* common : by default, move the penalties at the end of the sampling chain (ggml-ci)
* common : ignore all EOG tokens (Co-authored-by: Diego Devesa <slarengh@gmail.com>)
* common : move back the penalties at the front of the sampling chain (ggml-ci)
* readme : restore hint about --ignore-eos flag [no ci]
* llama : minor (ggml-ci)
* webui : update

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
1 parent 415666f commit 46fdd23

File tree

17 files changed (+111, -152 lines)

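The headline change: the penalties sampler no longer takes the vocabulary, the EOS/newline tokens, or the `penalize_nl` / `ignore_eos` flags, and it is now an ordinary member of the sampler chain instead of being hard-wired in front of it. A minimal sketch of building a chain against the refactored API (the signatures follow the diffs below; the concrete values are illustrative, not defaults):

```cpp
#include "llama.h"

// build a sampler chain with the refactored penalties sampler placed first,
// mirroring the new default sampler order in common_params_sampling
static llama_sampler * make_chain(uint32_t seed) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_penalties(
        /*penalty_last_n  =*/ 64,   // -1 now means "use the context size" (resolved by the caller)
        /*penalty_repeat  =*/ 1.1f,
        /*penalty_freq    =*/ 0.0f,
        /*penalty_present =*/ 0.0f));

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp (0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist (seed));

    return chain; // release with llama_sampler_free()
}
```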

Diff for: common/arg.cpp

+6, -7

@@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.ignore_eos = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
-        [](common_params & params) {
-            params.sampling.penalize_nl = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
         string_format("temperature (default: %.1f)", (double)params.sampling.temp),
@@ -916,6 +909,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--repeat-last-n"}, "N",
         string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
+            }
             params.sampling.penalty_last_n = value;
             params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
         }
@@ -970,6 +966,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--dry-penalty-last-n"}, "N",
         string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
+            }
             params.sampling.dry_penalty_last_n = value;
         }
     ).set_sparam());

Diff for: common/common.cpp

+19

@@ -940,6 +940,25 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }

+    if (params.sampling.ignore_eos) {
+        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+            if (llama_token_is_eog(model, i)) {
+                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                params.sampling.logit_bias.push_back({i, -INFINITY});
+            }
+        }
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
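With this change, `--ignore-eos` is implemented as a logit bias over every end-of-generation (EOG) token rather than as a flag inside the penalties sampler. A rough standalone sketch of the same idea using the public sampler API (this assumes `llama_sampler_init_logit_bias` and the `llama_logit_bias` struct from `llama.h`; it is not a copy of the code above):

```cpp
#include <cmath>
#include <vector>
#include "llama.h"

// bias every EOG token to -inf so it can never be sampled,
// which is what --ignore-eos now amounts to
static llama_sampler * make_ignore_eog_bias(const llama_model * model) {
    std::vector<llama_logit_bias> biases;

    for (llama_token i = 0; i < llama_n_vocab(model); i++) {
        if (llama_token_is_eog(model, i)) {
            biases.push_back({ i, -INFINITY });
        }
    }

    return llama_sampler_init_logit_bias(llama_n_vocab(model), (int32_t) biases.size(), biases.data());
}
```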

Diff for: common/common.h

+9, -6

@@ -95,6 +95,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };

 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +131,6 @@ struct common_params_sampling {
     int32_t mirostat     = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau = 5.00f; // target entropy
     float   mirostat_eta = 0.10f; // learning rate
-    bool    penalize_nl  = false; // consider newlines as a repeatable token
     bool    ignore_eos   = false;
     bool    no_perf      = false; // disable performance metrics
     bool    timing_per_token = false;
@@ -139,6 +139,7 @@ struct common_params_sampling {

     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -193,11 +194,13 @@ struct common_params {
     float defrag_thold = 0.1f; // KV cache defragmentation threshold

     // offload params
-    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu     = 0;  // the GPU that is used for scratch and small tensors
+    float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs

     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

Diff for: common/sampling.cpp

+11, -16

@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));

-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }

                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry(model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                    break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill(model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+       case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
        default : return '?';
    }
}
@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+       case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
        default : return "";
    }
}
@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+       { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    // since samplers names are written multiple ways
@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+       { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    std::vector<common_sampler_type> samplers;

Diff for: examples/batched/batched.cpp

+1

@@ -65,6 +65,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;

     llama_sampler * smpl = llama_sampler_chain_init(sparams);

Diff for: examples/main/README.md

-5

@@ -177,16 +177,11 @@ Example usage: `--temp 0`

 - `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
 - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
-- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.

 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.

 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).

-Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
-
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
-
 ### DRY Repetition Penalty

 DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).
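For reference, the repeat, frequency, and presence penalties described above combine per candidate token roughly as follows. This is a simplified sketch of the logit adjustment for a token that occurred `count` times within the last `repeat-last-n` tokens, following the formula the penalties sampler is commonly understood to use; it is not the literal implementation:

```cpp
// illustrative only: approximate per-token logit adjustment of the penalties sampler
float apply_penalties(float logit, int count, float penalty_repeat, float penalty_freq, float penalty_present) {
    if (count == 0) {
        return logit; // token did not appear in the penalized window
    }

    // repetition penalty: shrink positive logits, push negative logits further down
    if (logit <= 0.0f) {
        logit *= penalty_repeat;
    } else {
        logit /= penalty_repeat;
    }

    // OpenAI-style frequency and presence penalties
    logit -= float(count) * penalty_freq + penalty_present;

    return logit;
}
```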

Diff for: examples/server/README.md

-5

@@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
-| `--penalize-nl` | penalize newline tokens (default: false) |
 | `--temp N` | temperature (default: 0.8) |
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to

 `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.

-`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
-
 `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.

 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,
@@ -845,7 +841,6 @@ Example:
     "mirostat": 0,
     "mirostat_tau": 5.0,
     "mirostat_eta": 0.10000000149011612,
-    "penalize_nl": false,
     "stop": [],
     "max_tokens": -1,
     "n_keep": 0,

Diff for: examples/server/public/index.html.gz

2 Bytes
Binary file not shown.

Diff for: examples/server/public_legacy/index-new.html

-1

@@ -39,7 +39,6 @@
   temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
   repeat_last_n: 0, // 0 = disable penalty, -1 = context size
   repeat_penalty: 1.0, // 1.0 = disabled
-  penalize_nl: false, // true only useful for infinite completion
   dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
   dry_base: 1.75, // 0.0 = disabled
   dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well

Diff for: examples/server/public_legacy/index.html

-2

@@ -303,7 +303,6 @@
   temperature: 0.7,
   repeat_last_n: 256, // 0 = disable penalty, -1 = context size
   repeat_penalty: 1.18, // 1.0 = disabled
-  penalize_nl: false,
   dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
   dry_base: 1.75, // 0.0 = disabled
   dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
@@ -1006,7 +1005,6 @@
   ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
   ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
   ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-  ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
   ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
   ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
   ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

Diff for: examples/server/server.cpp

+23, -5

@@ -135,7 +135,6 @@ struct slot_params {
             {"mirostat",     sampling.mirostat},
             {"mirostat_tau", sampling.mirostat_tau},
             {"mirostat_eta", sampling.mirostat_eta},
-            {"penalize_nl",  sampling.penalize_nl},
             {"stop",         antiprompt},
             {"max_tokens",   n_predict}, // User configured n_predict
             {"n_keep",       n_keep},
@@ -184,6 +183,7 @@ struct server_task {

     static slot_params params_from_json_cmpl(
             const llama_model * model,
+            const llama_context * ctx,
             const common_params & params_base,
             const json & data) {
         slot_params params;
@@ -226,7 +226,6 @@ struct server_task {
         params.sampling.mirostat     = json_value(data, "mirostat",     defaults.sampling.mirostat);
         params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
         params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
-        params.sampling.penalize_nl  = json_value(data, "penalize_nl",  defaults.sampling.penalize_nl);
         params.sampling.seed         = json_value(data, "seed",         defaults.sampling.seed);
         params.sampling.n_probs      = json_value(data, "n_probs",      defaults.sampling.n_probs);
         params.sampling.min_keep     = json_value(data, "min_keep",     defaults.sampling.min_keep);
@@ -239,8 +238,27 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        // TODO: add more sanity checks for the input parameters
+
+        if (params.sampling.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (params.sampling.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (params.sampling.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (params.sampling.dry_penalty_last_n == -1) {
+            params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        }
+
         if (params.sampling.dry_base < 1.0f) {
-           params.sampling.dry_base = defaults.sampling.dry_base;
+            params.sampling.dry_base = defaults.sampling.dry_base;
         }

         // sequence breakers for DRY
@@ -1469,7 +1487,7 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);

         add_bos_token = llama_add_bos_token(model);
-        has_eos_token = !llama_add_eos_token(model);
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;

         if (!params_base.speculative.model.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@@ -3381,7 +3399,7 @@ int main(int argc, char ** argv) {
             task.index = i;

             task.prompt_tokens = std::move(tokenized_prompts[i]);
-            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
             task.id_selected_slot = json_value(data, "id_slot", -1);

             // OAI-compat
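A hypothetical client-side sketch of a `/completion` request body that would exercise the new checks in `params_from_json_cmpl` above, built with `nlohmann::json` as the server itself does. With this commit, `repeat_last_n: -1` is resolved server-side to the context size and values below -1 are rejected; the prompt and numbers here are placeholders:

```cpp
#include <nlohmann/json.hpp>

// assemble a /completion request body; field names follow the server README
nlohmann::json make_completion_request() {
    return nlohmann::json {
        {"prompt",            "Once upon a time"},
        {"n_predict",         128},
        {"repeat_last_n",     -1},   // -1 = penalize over the full context window
        {"repeat_penalty",    1.1},
        {"frequency_penalty", 0.0},
        {"presence_penalty",  0.0},
    };
}
```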

Diff for: examples/server/themes/buttons-top/index.html

-2

@@ -222,7 +222,6 @@
   temperature: 0.7,
   repeat_last_n: 256, // 0 = disable penalty, -1 = context size
   repeat_penalty: 1.18, // 1.0 = disabled
-  penalize_nl: false,
   top_k: 40, // <= 0 to use vocab size
   top_p: 0.95, // 1.0 = disabled
   min_p: 0.05, // 0 = disabled
@@ -779,7 +778,6 @@
   ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
   ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
   ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-  ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
   ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
   ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
   ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
