From 4d603e3520b8cba69b99a27c9f9b0d77e0e36439 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Thu, 25 Apr 2024 15:58:59 +0900
Subject: [PATCH] added DRY implementation

---
 llama.cpp | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 3a84b4916bd30..bb5aff46f800d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,6 +13233,90 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
+void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * seq_breakers, int seq_breakers_size) {
+    // sanity check
+    GGML_ASSERT(last_tokens_size > 0);
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if the last token is part of the sequence breakers, skip the whole sampler
+    if(std::find(seq_breakers, seq_breakers + seq_breakers_size, last_token) != seq_breakers + seq_breakers_size) {
+        return;
+    }
+
+    // create an unordered map of "next tokens" <-> max match length
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (exclude the last token)
+    for (size_t i = 0; i < (size_t) last_tokens_size - 1; ++i) {
+        // skip if the compare token is not the same as the last token
+        if(last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // try to extend the match backwards (match length starts at 1 because the last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for(;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if(i < match_length) break;
+
+            // compare token starts at our prev index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // if the compare token is part of the sequence breakers, break out of the match
+            if(std::find(seq_breakers, seq_breakers + seq_breakers_size, compare_token) != seq_breakers + seq_breakers_size)
+                break;
+
+            // break out of the match if any tokens don't match
+            if(compare_token != head_token)
+                break;
+        }
+
+        // check if the next token exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // key exists, update it with the max of the new value and the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than our allowed length in config, we apply penalties
+        if(match_length > (size_t) dry_allowed_length) {
+
+            // find our next token in candidates->data
+            size_t i = 0;
+            for (; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
+
+                    // apply the dry penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
+
 void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
     if (z >= 1.0f || candidates->size <= 2) {
         return;
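
A minimal usage sketch (not part of the patch above): one way the new llama_sample_dry() call might be wired into an existing sampling loop. The surrounding variables (ctx, candidates_p, last_tokens, seq_breakers) and the parameter values are assumptions for illustration only, not values taken from this PR.

    // Sketch only: assumes ctx and a prepared llama_token_data_array named candidates_p already exist.
    std::vector<llama_token> last_tokens  = /* recent context tokens, oldest first, newest last */ {};
    std::vector<llama_token> seq_breakers = { /* token ids that should stop a match, e.g. newline */ };

    llama_sample_dry(ctx, &candidates_p,
                     last_tokens.data(), (int) last_tokens.size(),
                     /* dry_base           */ 1.75f,
                     /* dry_multiplier     */ 0.80f,
                     /* dry_allowed_length */ 2,
                     seq_breakers.data(), (int) seq_breakers.size());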