Skip to content

Commit

Permalink
whisper : improve beam search candidate diversity (ggerganov#1947)
Browse files Browse the repository at this point in the history
As of ggerganov#1486, whisper.cpp uses a unified KV cache with KQ masking.
As a result, depending on their location in the batch,
identical sequences in a batch can have slightly different outputs
due to floating point rounding errors during reduction.
See the discussion in ggerganov#1941 for more details.

The beam search code used "has identical sum of log probabilities"
as a shorthand for "is an identical token sequence". However, per above,
identical tokens do not necessarily result in identical probabilities.

Instead, explicitly compare on sequences.
This is linear in cost when they are identical,
but the lengths are always small and the comparisons are cheap.

This increases diversity during beam search.

This improves output quality for some short samples I've been working
with, at no detectable performance cost.
I haven't checked against larger corpuses.

Fixes ggerganov#1941
  • Loading branch information
josharian authored Mar 10, 2024
1 parent 730a8b3 commit 54e3674
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4759,6 +4759,19 @@ static void whisper_process_logits(
#endif
}

static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
if (a.tokens.size() != b.tokens.size()) {
return false;
}
// sequences are more likely to diverge at the end
for (int i = a.tokens.size() - 1; i >= 0; i--) {
if (a.tokens[i].id != b.tokens[i].id) {
return false;
}
}
return true;
}

static whisper_token_data whisper_sample_token(
whisper_context & ctx,
const whisper_decoder & decoder,
Expand Down Expand Up @@ -5378,7 +5391,7 @@ int whisper_full_with_state(

auto & cur = beam_candidates[cur_c++];

while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
++cur_c;
}

Expand Down

0 comments on commit 54e3674

Please sign in to comment.