Working implementation of DRY with one key issue I could use help with #30

Merged: 2 commits, Aug 8, 2024. The diff below shows the changes from all commits.

common/sampling.cpp (4 changes: 2 additions & 2 deletions)
@@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
         {
             const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
             if (penalty_tokens_used_size) {
-                llama_sample_dry(&cur_p,
+                llama_sample_dry(ctx_main, &cur_p,
                                  penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                                  penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
-                                 params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+                                 params.dry_seq_breakers);
             }
         }

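The only functional changes at this call site are threading `ctx_main` through and passing the breaker strings by reference; the arithmetic that selects the last `dry_penalty_last_n` tokens is untouched. For readers skimming the hunk, here is that windowing logic as a minimal standalone sketch (the helper and its names are illustrative, not part of the patch; a negative last-n value is assumed to have been resolved to the context size beforehand, per the comment in common/sampling.h):

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Mirrors the hunk above: take at most `last_n` tokens from the end of the history.
template <typename Tok>
std::pair<const Tok *, size_t> dry_window(const std::vector<Tok> & penalty_tokens, size_t last_n) {
    const size_t used = std::min(penalty_tokens.size(), last_n);
    return { penalty_tokens.data() + penalty_tokens.size() - used, used };
}
```

With a 10-token history and `last_n = 3`, the returned pointer addresses the final 3 tokens, which is exactly the slice handed to `llama_sample_dry`.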
common/sampling.h (5 changes: 3 additions & 2 deletions)
@@ -46,6 +46,8 @@ typedef struct llama_sampling_params {
     uint32_t dry_allowed_length = 2;
     int32_t  dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
 
+    std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
+
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
         llama_sampler_type::TFS_Z,
@@ -63,9 +65,8 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
     std::vector<llama_token> penalty_prompt_tokens;
-    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
+
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;

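Because `dry_seq_breakers` now ships with defaults, frontends get reasonable DRY behavior with no setup at all. A minimal sketch of adjusting it (the helper function is illustrative; the fields are the ones shown in the struct above):

```cpp
#include "common/sampling.h"

void configure_dry(llama_sampling_params & sparams) {
    // The defaults {"\n", ":", "\"", "*"} are already present; extend rather than replace.
    sparams.dry_seq_breakers.push_back("###");

    sparams.dry_allowed_length = 2;  // repeats up to this length go unpenalized
    sparams.dry_penalty_last_n = -1; // -1 = penalize across the whole context
}
```

Moving the breakers from `std::vector<llama_token>` to `std::vector<std::string>` is what drives the API changes in the other files: breakers are now matched on detokenized text rather than on token ids.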
include/llama.h (33 changes: 23 additions & 10 deletions)
@@ -1085,16 +1085,17 @@ extern "C" {
             float   p,
             size_t  min_keep);
 
-    /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
-    LLAMA_API void llama_sample_dry(
-            llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-            size_t last_tokens_size,
-            float dry_base,
-            float dry_multiplier,
-            int dry_allowed_length,
-            const llama_token * dry_seq_breakers,
-            size_t dry_seq_breakers_size);
+    // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+    // LLAMA_API void llama_sample_dry(
+    //         struct llama_context * ctx,
+    //         llama_token_data_array * candidates,
+    //         const llama_token * last_tokens,
+    //         size_t last_tokens_size,
+    //         float dry_base,
+    //         float dry_multiplier,
+    //         int dry_allowed_length,
+    //         const std::vector<std::string> & dry_seq_breakers);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
     LLAMA_API void llama_sample_tail_free(

@@ -1246,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
 llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
 
+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        const llama_token * last_tokens,
+        size_t last_tokens_size,
+        float dry_base,
+        float dry_multiplier,
+        int dry_allowed_length,
+        const std::vector<std::string> & dry_seq_breakers);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
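The public declaration is commented out rather than updated because the new signature takes `std::vector<std::string>`, which cannot appear inside the `extern "C"` block; the working declaration therefore moves into the C++-only `LLAMA_API_INTERNAL` section. A sketch of what a C++ caller has to do as a result (the wrapper and parameter values are illustrative, not from the patch):

```cpp
// C++ callers must opt into the internal API before including llama.h.
#define LLAMA_API_INTERNAL
#include "llama.h"

#include <string>
#include <vector>

void apply_dry(llama_context * ctx, llama_token_data_array * cur_p,
               const std::vector<llama_token> & prev) {
    if (prev.empty()) {
        return; // the sampler needs at least one previous token
    }
    const std::vector<std::string> breakers = {"\n", ":", "\"", "*"};
    llama_sample_dry(ctx, cur_p, prev.data(), prev.size(),
                     /*dry_base*/ 1.75f, /*dry_multiplier*/ 0.8f,
                     /*dry_allowed_length*/ 2, breakers);
}
```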
src/llama-sampling.cpp (210 changes: 173 additions & 37 deletions)
@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
     }
 }
 
-void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    // skip dry sampler if we don't have a previous token
-    if (last_tokens_size < 1) return;
+std::vector<llama_token> llama_tokenize(
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
+        const std::string & text,
+        bool add_special,
+        bool parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
+    std::vector<llama_token> tokens = {token};
+    return llama_detokenize(ctx, tokens, special);
+}
 
-    // get the last token
-    auto last_token = last_tokens[last_tokens_size - 1];
+// Constants for preventing overflow
+const float FLOAT_MAX_LOG = 88.7228391f;
+const int MAX_CHAR_LEN = 40;
+const int MAX_SEQ_LEN = 20;
 
-    // if last token is part of the sequence breakers, skip whole sampler
-    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    if (last_tokens_size < 1) {
         return;
     }
 
-    // create an unordered map of "next tokens" <-> max match length
+    // Cache for token-to-string conversions
+    std::unordered_map<llama_token, std::string> token_to_string_cache;
+    // Store sequence breakers for more efficient lookup
+    std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;
+
+    auto detokenize_with_cache = [&](llama_token token) -> std::string {
+        auto it = token_to_string_cache.find(token);
+        if (it != token_to_string_cache.end()) {
+            return it->second;
+        }
+        std::string token_str = llama_detokenize_single(ctx, token, false);
+        token_to_string_cache[token] = token_str;
+        return token_str;
+    };
+
+    // Pre-process dry_seq_breakers
+    for (const auto& breaker : dry_seq_breakers) {
+        std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);
+
+        if (!tokens.empty()) {
+            std::string head = detokenize_with_cache(tokens[0]);
+            std::vector<std::string> tail;
+
+            for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
+                tail.push_back(detokenize_with_cache(tokens[i]));
+            }
+            restart_sequences.emplace(head, tail);
+        }
+    }
+
+    // Find max repetition length considering restart sequences
+    int rep_limit = last_tokens_size;
+
+    for (size_t i = 0; i < last_tokens_size; ++i) {
+        size_t ix = last_tokens_size - 1 - i;
+        std::string token_str = detokenize_with_cache(last_tokens[ix]);
+
+        // Check if the token is a potential sequence breaker
+        auto its = restart_sequences.equal_range(token_str);
+        if (its.first == restart_sequences.end()) continue;
+
+        int longest_match = -1;
+        // Check all potential sequence breakers starting with this token
+        for (auto it = its.first; it != its.second; ++it) {
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= i) {
+                bool match = true;
+                // Check if the following tokens match the sequence breaker
+                for (size_t offset = 0; offset < seq_len; ++offset) {
+                    if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+
+        if (longest_match >= 0) {
+            rep_limit = static_cast<int>(i) - longest_match;
+            break;
+        }
+    }
+
+    if (rep_limit <= dry_allowed_length) {
+        return;
+    }
+
+    // Store max match length for each token
     std::unordered_map<llama_token, size_t> match_lengths;
 
-    // loop through each previous token (exclude the last token)
+    // Find repeated sequences
     for (size_t i = 0; i < last_tokens_size - 1; ++i) {
-        // skip if the compare token is not the same as the last token
-        if (last_tokens[i] != last_token) {
+        if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
            continue;
         }
 
-        // get the next token (i + 1 is always less than last_tokens_size)
         auto next_token = last_tokens[i + 1];
+        std::string next_token_str = detokenize_with_cache(next_token);
 
-        // if next token is part of the sequence breakers, skip
-        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        // Skip if next token is a sequence breaker
+        auto its = restart_sequences.equal_range(next_token_str);
+        if (its.first != restart_sequences.end()) {
             continue;
         }
 
-        // try to extend the match backwards (match length starts at 1 because last token is already matched)
         size_t match_length = 1;
 
-        // loop through the previous tokens
+        // Extend match as far as possible
         for (;; match_length++) {
-            // if we have reached the start of our last tokens, break
-            if (i < match_length) break;
+            if (i < match_length || match_length > rep_limit) {
+                break;
+            }
 
-            // compare token starts at our prev index, going backwards by match length
             auto compare_token = last_tokens[i - match_length];
+            std::string compare_token_str = detokenize_with_cache(compare_token);
 
-            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
             auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+            std::string head_token_str = detokenize_with_cache(head_token);
 
-            // break out of the match if any tokens don't match
-            if (compare_token != head_token) {
+            if (compare_token_str != head_token_str) {
                 break;
             }
 
-            // if compare token is part of the sequence breakers, break out of the match
-            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            // Check if we've hit a sequence breaker
+            its = restart_sequences.equal_range(compare_token_str);
+            if (its.first != restart_sequences.end()) {
                 break;
             }
         }
 
-        // Check if the next token exists in the map
+        // Update max match length for this token
         auto it = match_lengths.find(next_token);
 
         if (it == match_lengths.end()) {
             // Key does not exist, insert the new value
             match_lengths[next_token] = match_length;
         } else {
             // Key exists, update it with the max of the new value or the existing value
             it->second = std::max(it->second, match_length);
         }
     }
 
-    // apply penalties
+    // Calculate max safe exponent
+    int max_exponent = 0;
+    if (dry_base > 1.000001f) {
+        max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
+    }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("DRY Sampling parameters:\n");
+    LLAMA_LOG_INFO("  dry_base: %f\n", dry_base);
+    LLAMA_LOG_INFO("  dry_multiplier: %f\n", dry_multiplier);
+    LLAMA_LOG_INFO("  dry_allowed_length: %d\n", dry_allowed_length);
+    LLAMA_LOG_INFO("  max_exponent: %d\n", max_exponent);
+    LLAMA_LOG_INFO("DRY penalties [");
+#endif
+
+    // Apply penalties
     for (const auto& pair : match_lengths) {
         auto next_token = pair.first;
         auto match_length = pair.second;
 
-        // if the match length is greater than or equal to our allowed length in config, we apply penalities
-        if (match_length >= (size_t)dry_allowed_length) {
-
-            // find our next token in the candidates->data
+        if (match_length >= static_cast<size_t>(dry_allowed_length)) {
             for (size_t i = 0; i < candidates->size; ++i) {
                 if (candidates->data[i].id == next_token) {
-                    // calculate the penalty
-                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
-
-                    // apply the dry penalty
+                    int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
+                    if (max_exponent > 0 && repeat_exp > max_exponent) {
+                        repeat_exp = max_exponent;
+                    }
+                    float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
                     candidates->data[i].logit -= penalty;
+
+#ifdef DEBUG
+                    LLAMA_LOG_INFO("  Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
+#endif
                     break;
                 }
             }
         }
     }
+
+#ifdef DEBUG
+    LLAMA_LOG_INFO("]\n");
+#endif
 }
 
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
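To make the penalty schedule concrete: with `dry_base = 1.75`, `dry_multiplier = 0.8` and `dry_allowed_length = 2`, a repeated sequence of length 5 loses `0.8 * 1.75^(5-2) ≈ 4.29` logits. The `max_exponent` clamp exists because `pow(dry_base, n)` overflows `float` once `n * log(dry_base)` exceeds `FLOAT_MAX_LOG ≈ 88.72`. A standalone sketch of just that arithmetic (constants copied from the patch; the driver loop and values are illustrative):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float FLOAT_MAX_LOG = 88.7228391f; // ~log(FLT_MAX), as in the patch
    const float dry_base = 1.75f;
    const float dry_multiplier = 0.8f;
    const int   dry_allowed_length = 2;

    // Same guard as the patch: cap the exponent so pow() cannot overflow.
    int max_exponent = 0;
    if (dry_base > 1.000001f) {
        max_exponent = (int)(FLOAT_MAX_LOG / std::log(dry_base));
    }

    for (int match_length = 2; match_length <= 6; ++match_length) {
        int repeat_exp = match_length - dry_allowed_length;
        if (max_exponent > 0 && repeat_exp > max_exponent) {
            repeat_exp = max_exponent;
        }
        float penalty = dry_multiplier * std::pow(dry_base, (float)repeat_exp);
        std::printf("match %d -> penalty %.3f\n", match_length, penalty);
    }
    return 0;
}
```

Since growth is exponential, the clamp only matters for pathological match lengths; at `dry_base = 1.75` it allows exponents up to about 158.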
src/llama-sampling.h (6 changes: 5 additions & 1 deletion)
@@ -28,7 +28,11 @@ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
 void llama_sample_top_k_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
 void llama_sample_top_p_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_min_p_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_dry_impl   (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
+std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
+std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
+std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
+void llama_sample_dry_impl   (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
 void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
 void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
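The `llama_tokenize`/`llama_detokenize` helpers declared above follow llama.cpp's usual two-pass buffer convention: call once with a guessed capacity, and a negative return value's magnitude is the size actually required. A generic, self-contained sketch of that idiom (not tied to these exact helpers):

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// fn(buf, cap) returns the count written, or a negative value whose
// magnitude is the capacity required for the full result.
template <typename Fn>
std::string grow_to_fit(Fn fn, size_t initial_cap) {
    std::string out(initial_cap, '\0');
    int n = fn(&out[0], (int)out.size());
    if (n < 0) {
        out.resize((size_t)-n); // second pass with the exact size
        n = fn(&out[0], (int)out.size());
        assert(n >= 0 && (size_t)n <= out.size());
    }
    out.resize((size_t)n);
    return out;
}
```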
src/llama.cpp (4 changes: 2 additions & 2 deletions)
@@ -18935,8 +18935,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
     llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
 }
 
-void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
-    llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size);
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
+    llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
 }
 
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
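Finally, the core idea behind the restart-sequence preprocessing in src/llama-sampling.cpp, reduced to a toy: each breaker string is tokenized once, the text of its first piece becomes the multimap key, and the remaining pieces are kept as a tail that must also match before the breaker counts. A self-contained model of that lookup (plain strings stand in for detokenized tokens; the three-piece split of "###" is an assumption for illustration):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    // head piece -> tail pieces (what the patch calls restart_sequences)
    std::unordered_multimap<std::string, std::vector<std::string>> restart;
    restart.emplace("\n", std::vector<std::string>{});        // single-piece breaker
    restart.emplace("#", std::vector<std::string>{"#", "#"}); // "###" as three pieces

    std::vector<std::string> history = {"foo", "#", "#", "#", "bar"};
    size_t pos = 1; // candidate breaker head at history[1]

    auto range = restart.equal_range(history[pos]);
    for (auto it = range.first; it != range.second; ++it) {
        bool match = true;
        for (size_t k = 0; k < it->second.size(); ++k) {
            if (pos + 1 + k >= history.size() || history[pos + 1 + k] != it->second[k]) {
                match = false;
                break;
            }
        }
        std::cout << (match ? "breaker matched\n" : "no match\n");
    }
    return 0;
}
```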