Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added implementation of DRY sampler #6839

Closed
wants to merge 21 commits into master from dry-sampler
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f64dea0
added implementation of DRY sampler
l3utterfly Apr 25, 2024
aea4ad0
fixed editor config check
l3utterfly Apr 25, 2024
4d603e3
added DRY implementation
l3utterfly Apr 25, 2024
75beda2
fixed various issues with sampler pointed out by original creator
l3utterfly Apr 29, 2024
85dadac
added parameter for DRY penalty range, separate from the original rep…
l3utterfly Apr 29, 2024
793e1e2
updated header def for dry sampler to match implementation
l3utterfly Apr 29, 2024
3caec6b
removed unused llama_context in dry sampler
l3utterfly Apr 29, 2024
49e078f
changed array size parameters to size_t
l3utterfly Apr 29, 2024
2f9a36a
Merge branch 'master' into dry-sampler
l3utterfly Jul 29, 2024
802ddd7
added sample_dry_impl
l3utterfly Jul 29, 2024
12bfa78
added llama_sample_dry_impl in header
l3utterfly Jul 29, 2024
0229fc8
added final new line for editor config check
l3utterfly Jul 29, 2024
236da59
fixed int/size_t comparison
l3utterfly Jul 29, 2024
e862def
use int32_t for dry_penalty_last_n due to negative value needed as co…
l3utterfly Jul 29, 2024
9105cf4
Add DRY sampling parameters to gpt_params and server_context
wwoodsTM Aug 5, 2024
20dc562
Delete pr-6839.diff
wwoodsTM Aug 5, 2024
d1676a1
Merge pull request #29 from wwoodsTM/test-dry-sampler
l3utterfly Aug 6, 2024
ed6b909
Merge branch 'master' into dry-sampler
l3utterfly Aug 6, 2024
6579e64
Attempt at slightly optimized vector of strings DRY implementation
wwoodsTM Aug 6, 2024
a18fb2f
Merge remote-tracking branch 'myfork/test-dry-sampler' into test-dry-…
wwoodsTM Aug 6, 2024
190898a
Merge pull request #30 from wwoodsTM/test-dry-sampler
l3utterfly Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.penalty_present = std::stof(argv[i]);
return true;
}
if (arg == "--dry-multiplier") {
CHECK_ARG
sparams.dry_multiplier = std::stof(argv[i]);
return true;
}
if (arg == "--dry-base") {
CHECK_ARG
sparams.dry_base = std::stof(argv[i]);
return true;
}
if (arg == "--dry-allowed-length") {
CHECK_ARG
sparams.dry_allowed_length = std::stoi(argv[i]);
return true;
}
if (arg == "--dry-penalty-last-n") {
CHECK_ARG
sparams.dry_penalty_last_n = std::stoi(argv[i]);
return true;
}
if (arg == "--dynatemp-range") {
CHECK_ARG
sparams.dynatemp_range = std::stof(argv[i]);
Expand Down Expand Up @@ -1471,6 +1491,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
options.push_back({ "*", " --dry-multiplier N", "DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)sparams.dry_multiplier });
options.push_back({ "*", " --dry-base N", "DRY sampling base (default: %.1f)", (double)sparams.dry_base });
options.push_back({ "*", " --dry-allowed-length N", "DRY sampling allowed length (default: %d)", sparams.dry_allowed_length });
options.push_back({ "*", " --dry-penalty-last-n N", "DRY sampling penalty last n tokens (-1 = context size, default: %d)", sparams.dry_penalty_last_n });

options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n"
Expand Down
51 changes: 36 additions & 15 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,19 @@ static llama_token_data_array llama_sampling_prepare_impl(

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

// repetition penalties
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;

const bool penalize_nl = params.penalize_nl;

// DRY sampler parameters
const float dry_multiplier = params.dry_multiplier;
const float dry_base = params.dry_base;
const uint32_t dry_allowed_length = params.dry_allowed_length;
const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;

Expand Down Expand Up @@ -399,26 +405,41 @@ static llama_token_data_array llama_sampling_prepare_impl(

llama_token_data_array cur_p = { cur.data(), cur.size(), false };

// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;

// apply repetition penalties
{
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

// repetition penalties
llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}
}

// apply DRY penalties
{
const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
if (penalty_tokens_used_size) {
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_seq_breakers);
}
}

// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
Expand Down
8 changes: 7 additions & 1 deletion common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ typedef struct llama_sampling_params {
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
uint32_t dry_allowed_length = 2;
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)

std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
Expand All @@ -59,8 +65,8 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance

std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

std::vector<llama_token> penalty_prompt_tokens;

bool use_penalty_prompt_tokens = false;
} llama_sampling_params;

Expand Down
77 changes: 53 additions & 24 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,30 +901,54 @@ struct server_context {
slot.oaicompat_model = "";
}

slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot.params.stream = json_value(data, "stream", false);
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

// sequence breakers for DRY
{
auto dry_seq_breakers = data.find("dry_seq_breakers");
if (dry_seq_breakers != data.end()) {
try {
if (dry_seq_breakers->is_array()) {
slot.sparams.dry_seq_breakers = dry_seq_breakers->get<std::vector<std::string>>();
} else if (dry_seq_breakers->is_string()) {
slot.sparams.dry_seq_breakers = json::parse(dry_seq_breakers->get<std::string>()).get<std::vector<std::string>>();
} else {
send_error(task, "\"dry_seq_breakers\": Expected an array of strings or a JSON-encoded array of strings.", ERROR_TYPE_INVALID_REQUEST);
return false;
}
} catch (const std::exception & e) {
send_error(task, std::string("\"dry_seq_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
}

// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
Expand Down Expand Up @@ -1342,6 +1366,11 @@ struct server_context {
{"frequency_penalty", slot.sparams.penalty_freq},
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
{"dry_multiplier", slot.sparams.dry_multiplier},
{"dry_base", slot.sparams.dry_base},
{"dry_allowed_length", slot.sparams.dry_allowed_length},
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
{"dry_seq_breakers", slot.sparams.dry_seq_breakers},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
Expand Down
24 changes: 24 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,18 @@ extern "C" {
float p,
size_t min_keep);

// /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
// LLAMA_API void llama_sample_dry(
// struct llama_context * ctx,
// llama_token_data_array * candidates,
// const llama_token * last_tokens,
// size_t last_tokens_size,
// float dry_base,
// float dry_multiplier,
// int dry_allowed_length,
// const std::vector<std::string>
// & dry_seq_breakers);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
struct llama_context * ctx,
Expand Down Expand Up @@ -1235,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);

/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
LLAMA_API void llama_sample_dry(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
const std::vector<std::string>
& dry_seq_breakers);

#endif // LLAMA_API_INTERNAL

#endif // LLAMA_H
Loading
Loading