whisper : fine-tuning grammar functionality #1
```diff
@@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init(
 static void whisper_suppress_invalid_grammar(
              whisper_context & ctx,
     const whisper_full_params & params,
-          std::vector<float> & logits,
+          std::vector<float> & logprobs,
       const whisper_grammar & grammar) {
 
     if (grammar.rules.empty() || grammar.stacks.empty()) {
```
```diff
@@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar(
     std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
     std::vector<whisper_grammar_candidate> candidates_grammar;
 
-    size_t size = logits.size();
-    for (whisper_token id = 0; id < size; ++id) {
+    size_t size = logprobs.size();
+    for (whisper_token id = 0; id < (int) size; ++id) {
         const std::string & text = ctx.vocab.id_to_token[id];
         if (!text.empty() && text.rfind("[_", 0) != 0) {
             candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
```
```diff
@@ -3893,25 +3893,29 @@ static void whisper_suppress_invalid_grammar(
         }
     }
 
     const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
 
     for (const auto & reject : rejects) {
-        if (logits[reject.id] > 0) {
-            logits[reject.id] /= params.grammar_penalty;
-        } else {
-            logits[reject.id] *= params.grammar_penalty;
-        }
+        logprobs[reject.id] -= params.grammar_penalty;
```
Review comment: I think this makes more sense, though I'm still experimenting.

Reply: Yeah, this seems to work well enough. I've been testing this code with values of 10 to 15 on tiny. I was hoping for a sweet spot that would guide the decoding towards valid strings while leaving obviously non-matching strings unchanged, but so far the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

Reply: I suppose both variants fail in some way for different use cases. This penalty is essentially a free parameter, and it's not obvious how to set it. A more general way would be to let the user provide a penalty callback function and let them implement the penalization, but I guess for now let's keep it simple.

Review comment: Does the original scaling work better in this scenario?

Reply: Anecdotally, no, not really.
```diff
     }
-    // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+
+    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
+    // TODO: is there are better way to do this?
+    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
+    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+        logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
+    }
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
```
Review comment: Seems we have an issue with EOT handling.

Reply: There's actually the start of that, commented out, at the top of the function; I was just unsure initially whether that penalty made sense in whisper land. The decoded string is complete according to the grammar if some stack is empty, so the following will penalize EOT if the decoded string is an incomplete parse. The grammar might still allow continuation after a complete parse, so if we want to penalize EOT only when there is no continuation at all, I guess we'd switch the check to look for a nonempty stack.

```diff
diff --git a/whisper.cpp b/whisper.cpp
index 5e3b86a..e796ddd 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar(
         return;
     }
 
-    // bool allow_eot = false;
-    // for (const auto & stack : grammar.stacks) {
-    //     if (stack.empty()) {
-    //         allow_eot = true;
-    //         break;
-    //     }
-    // }
+    bool allow_eot = false;
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            allow_eot = true;
+            break;
+        }
+    }
 
     std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
     std::vector<whisper_grammar_candidate> candidates_grammar;
@@ -3898,10 +3898,8 @@ static void whisper_suppress_invalid_grammar(
         logprobs[reject.id] -= params.grammar_penalty;
     }
 
-    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
-    // TODO: is there are better way to do this?
-    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
-    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+    // penalize EOT if grammar is incomplete
+    if (!allow_eot) {
         logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
     }
 
     //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
```
```diff
 }
 
 static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
     if (grammar.rules.empty() || grammar.stacks.empty()) {
         return;
     }
 
-    // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str());
+    fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
 
     const std::string & text = ctx.vocab.id_to_token[token];
 
     if (text.rfind("[_", 0) == 0) {
         // fprintf(stderr, " (skipped)\n");
         return;
```
```diff
@@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.grammar_rules   =*/ nullptr,
         /*.n_grammar_rules =*/ 0,
         /*.i_start_rule    =*/ 0,
-        /*.grammar_penalty =*/ 1000.0f,
+        /*.grammar_penalty =*/ 100.0f,
     };
 
     switch (strategy) {
```
```diff
@@ -4181,12 +4185,18 @@ static void whisper_process_logits(
         logits[vocab.token_translate]  = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;
 
+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }
 
-        whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
-
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
         if (params.suppress_non_speech_tokens) {
```

Review comment (on lines +4193 to +4200): These are unrelated to the grammar functionality; I think they improve the transcription in general.
```diff
@@ -4293,10 +4303,19 @@ static void whisper_process_logits(
             //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
             if (timestamp_logprob > max_text_token_logprob) {
+                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i]   = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else {
+                //printf("sampling text\n");
+                for (int i = vocab.token_beg; i < n_logits; ++i) {
+                    logits[i]   = -INFINITY;
+                    logprobs[i] = -INFINITY;
+                }
+
+                whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
             }
         }
     }
```

Review comment: Moving this down here improves the results in my tests significantly. The idea is that we keep the logic for sampling a timestamp token intact, irrespective of what the grammar suggests. Only after we have decided that the next token is text do we apply the grammar filters.

Reply: Makes sense! This was my first foray into whisper internals, so I don't have a nuanced understanding of the decoding process.
```diff
@@ -4312,34 +4331,57 @@ static void whisper_process_logits(
         }
     }
 
-#if 0
+#if 1
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token = vocab.id_to_token.at(i);
-        const auto prob = probs[i];
-        const auto logit = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
-    }
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token = vocab.id_to_token.at(i);
+    //    const auto prob = probs[i];
+    //    const auto logit = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token = vocab.id_to_token.at(pairs[i].second);
+            const auto prob = pairs[i].first;
+            const auto logit = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
+    }
 
     // "And", "and", " And", " and"
-    printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
-
-    printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
-
-    printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+
+    //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+
+    //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }
```
Review comment: This seems to make the command recognition example much more robust, although we want to enable it from time to time to make sure that multi-decoder grammar usage is not broken.

Review comment: Seems the reason this improves the results is that the fallback logic depends on the average logprob of the generated sequences:

https://github.com/ggerganov/whisper.cpp/blob/b8f34d1ed786194d9787b1f1d086b89136e361f3/whisper.cpp#L5166-L5171

Therefore, I think both the "best of" and "beam search" strategies would not work as expected. We probably need some re-normalization of the logprobs after applying the grammar to make this compatible with them. In any case, I think for now we should focus on greedy sampling without fallbacks and improve later if needed.