Skip to content

Commit

Permalink
fixed bug in dry sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
l3utterfly committed Apr 24, 2024
1 parent 99b7760 commit 75c37ed
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 40 deletions.
104 changes: 65 additions & 39 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12832,60 +12832,86 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
}
}

void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_token_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
// loop through each candidate
for (size_t i = 0; i < candidates->size; ++i) {
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
// sanity check
GGML_ASSERT(last_tokens_size > 0);

// get the last token
auto last_token = last_tokens[last_tokens_size - 1];

// if last token is part of the sequence breakers, skip whole sampler
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
return;
}

// if our candidate itself is part of the sequence breakers, we don't apply the dry penalty
if (std::find(seq_breakers, seq_breakers + seq_breakers_size, candidates->data[i].id) != seq_breakers + seq_breakers_size) {
// create an unordered map of "next tokens" <-> max match length
std::unordered_map<llama_token, size_t> match_lengths;

// loop through each previous token (exclude the last token)
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
// skip the compare token if it's not the same as the last token
if(last_tokens[i] != last_token) {
continue;
}

int max_match_length = 0;
// get the next token (i + 1 is always less than last_tokens_size)
auto next_token = last_tokens[i + 1];

// loop through each previous token
for (size_t j = 0; j < last_token_size; ++j) {
// if the current candidate is the same as the previous token
if (candidates->data[i].id == last_tokens[j]) {
// greedily match sequence backwards starting from the current position with the end of prev
int match_length = 1;
// try to extend the match backwards (match length starts at 1 because the last token is already matched)
size_t match_length = 1;

// loop through the previous tokens
for(;; match_length++) {
// if we have reached the start of our stored prev, break
if(j - match_length > 0) break;
// loop through the previous tokens
for(;; match_length++) {
// if we have reached the start of our last tokens, break
if(i < match_length) break;

// this shouldn't happen because (j - match_length) should always be smaller than (size - match_length)
// but let's check here to avoid the unexpected
if(last_token_size - match_length < 0) break;
// compare token starts at our prev index, going backwards by match length
auto compare_token = last_tokens[i - match_length];

// compare token starts at our prev index, going backwards by match length
auto compare_token = last_tokens[j - match_length];
// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
auto head_token = last_tokens[last_tokens_size - 1 - match_length];

// head token starts at the end of prev, going backwards by match length
auto head_token = last_tokens[last_token_size - match_length];
// if compare token is part of the sequence breakers, break out of the match
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
break;

// if compare token is part of the sequence breakers, break out of the match
if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
break;
// break out of the match if any tokens don't match
if(compare_token != head_token)
break;
}

// break out of the match if any tokens don't match
if(compare_token != head_token)
break;
}
// Check if the next token exists in the map
auto it = match_lengths.find(next_token);

// update our max match length
max_match_length = std::max(max_match_length, match_length);
}
if (it == match_lengths.end()) {
// Key does not exist, insert the new value
match_lengths[next_token] = match_length;
} else {
// Key exists, update it with the max of the new value or the existing value
it->second = std::max(it->second, match_length);
}
}

// apply penalties
if(max_match_length > dry_allowed_length) {
// calculate the penalty
float penalty = dry_multiplier * pow(dry_base, max_match_length - dry_allowed_length);
// apply penalties
for (const auto& pair : match_lengths) {
auto next_token = pair.first;
auto match_length = pair.second;

// apply the dry penalty
candidates->data[i].logit -= penalty;
// if the match length is greater than our allowed length in config, we apply penalties
if(match_length > dry_allowed_length) {

// find our next token in the candidates->data
size_t i = 0;
for (; i < candidates->size; ++i) {
if (candidates->data[i].id == next_token) {
// calculate the penalty
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);

// apply the dry penalty
candidates->data[i].logit -= penalty;
break;
}
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
int last_token_size,
int last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
Expand Down

0 comments on commit 75c37ed

Please sign in to comment.