Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

whisper : fine-tuning grammar functionality #1

Merged
merged 5 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;

float vad_thold = 0.6f;
float freq_thold = 100.0f;
float vad_thold = 0.6f;
float freq_thold = 100.0f;

float grammar_penalty = 100.0f;

bool speed_up = false;
Expand Down Expand Up @@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;

// disable fallback - seems not useful for command recognition
wparams.temperature_inc = 0.0f;
Copy link
Author

@ggerganov ggerganov Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to make the command recognition example much more robust, although we want to enable it from time to time to make sure that multi-decoder grammar usage is not broken.

Copy link
Author

@ggerganov ggerganov Sep 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems the reason for this to improve the results is that the fallback logic depends on the average logprob of the generated sequences:

https://github.com/ggerganov/whisper.cpp/blob/b8f34d1ed786194d9787b1f1d086b89136e361f3/whisper.cpp#L5166-L5171

Therefore, I think both "best of" and "beam search" strategies would not work as expected.
We probably need to have some re-normalization of the logprobs after applying the grammar to make this compatible with these.

In any case, I think for now we should focus on greedy sampling without fallbacks and improve later if needed


wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;

Expand Down Expand Up @@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi

// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
Expand All @@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;

const std::string k_prompt = "Ok Whisper, start listening for commands.";
//const std::string k_prompt = "Ok Whisper, start listening for commands.";
//const std::string k_prompt = "Начало.";
const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди.";

fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
Expand Down Expand Up @@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());

// append 1 second of silence
pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f);

ggerganov marked this conversation as resolved.
Show resolved Hide resolved
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));

prob = 100.0f*(prob - prob0);
Expand All @@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
}
}

fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str());
if (best_len == 0) {
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
} else {
Expand Down
118 changes: 80 additions & 38 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init(
static void whisper_suppress_invalid_grammar(
whisper_context & ctx,
const whisper_full_params & params,
std::vector<float> & logits,
std::vector<float> & logprobs,
const whisper_grammar & grammar) {

if (grammar.rules.empty() || grammar.stacks.empty()) {
Expand All @@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar(
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
std::vector<whisper_grammar_candidate> candidates_grammar;

size_t size = logits.size();
for (whisper_token id = 0; id < size; ++id) {
size_t size = logprobs.size();
for (whisper_token id = 0; id < (int) size; ++id) {
const std::string & text = ctx.vocab.id_to_token[id];
if (!text.empty() && text.rfind("[_", 0) != 0) {
candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
Expand All @@ -3893,25 +3893,29 @@ static void whisper_suppress_invalid_grammar(
}

const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);

for (const auto & reject : rejects) {
if (logits[reject.id] > 0) {
logits[reject.id] /= params.grammar_penalty;
} else {
logits[reject.id] *= params.grammar_penalty;
}
logprobs[reject.id] -= params.grammar_penalty;
Copy link
Author

@ggerganov ggerganov Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes more sense, though I'm still experimenting.
At least for the use case where we want to strongly restrict the output to match the grammar, this works well with large penalty (>100.0)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this seems to work well enough - I've been testing this code with 10 - 15 on tiny. I was hoping for a sweet spot that would guide the decoding towards valid strings while leaving obviously non-matching strings unchanged. But so far, I've found that the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose both variants fail in some way for different use cases. This penalty is essentially a free parameter, and it's not obvious how to set it. A more general way would be to let the user provide a penalty callback function and let them implement the penalization, but I guess for now let's keep it simple

But so far, I've found that the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

Does the original scaling work better in this scenario?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the original scaling work better in this scenario?

Anecdotally, no, not really.

}
// fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());

// when the grammar does not allow any continuation, we don't want to penalize the EOT token
// TODO: is there are better way to do this?
printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
}
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
Copy link
Author

@ggerganov ggerganov Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we have an issue with EOT handling.
Here I try to penalize EOT if the grammar allows continuation and this seems to help, although it's ugly.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's actually the start of that, commented out, at the top of the function. I was just unsure initially if that penalty made sense in whisper land.

The decoded string is complete according to the grammar if some stack is empty, so the following will penalize EOT if the decoded string is an incomplete parse. The grammar might still allow continuation after a complete parse, so if we want to only penalize EOT if there's no continuation at all, I guess we'd switch the check to find a nonempty stack.

diff --git a/whisper.cpp b/whisper.cpp
index 5e3b86a..e796ddd 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar(
         return;
     }
 
-    // bool allow_eot = false;
-    // for (const auto & stack : grammar.stacks) {
-    //     if (stack.empty()) {
-    //         allow_eot = true;
-    //         break;
-    //     }
-    // }
+    bool allow_eot = false;
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            allow_eot = true;
+            break;
+        }
+    }
 
     std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
     std::vector<whisper_grammar_candidate>                              candidates_grammar;
@@ -3898,10 +3898,8 @@ static void whisper_suppress_invalid_grammar(
         logprobs[reject.id] -= params.grammar_penalty;
     }
 
-    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
-    // TODO: is there are better way to do this?
-    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
-    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+    // penalize EOT if grammar is incomplete
+    if (!allow_eot) {
         logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
     }
     //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());

}

static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
if (grammar.rules.empty() || grammar.stacks.empty()) {
return;
}

// fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str());
fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());

const std::string & text = ctx.vocab.id_to_token[token];

if (text.rfind("[_", 0) == 0) {
// fprintf(stderr, " (skipped)\n");
return;
Expand Down Expand Up @@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.grammar_rules =*/ nullptr,
/*.n_grammar_rules =*/ 0,
/*.i_start_rule =*/ 0,
/*.grammar_penalty =*/ 1000.0f,
/*.grammar_penalty =*/ 100.0f,
};

switch (strategy) {
Expand Down Expand Up @@ -4181,12 +4185,18 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;

// suppress lang tokens
for (size_t i = 0; i < g_lang.size(); ++i) {
logits[whisper_token_lang(&ctx, i)] = -INFINITY;
}

// suppress prev token
logits[vocab.token_prev] = -INFINITY;

Comment on lines +4193 to +4200
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are unrelated to the grammar functionality - I think they improve the transcription in general.
Not sure why OpenAI doesn't have these filters in the original implementation (or maybe they do?)

if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}

whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);

// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) {
Expand Down Expand Up @@ -4293,10 +4303,19 @@ static void whisper_process_logits(
//log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);

if (timestamp_logprob > max_text_token_logprob) {
//printf("sampling timestamp\n");
for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
} else {
//printf("sampling text\n");
for (int i = vocab.token_beg; i < n_logits; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}

whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving this down here improves the results in my tests significantly. The idea is that we keep the logic for sampling a timestamp token intact, irrespective of what the grammar suggests. Only after we have made the decision about the next token being text, we apply the grammar filters.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! This was my first foray into whisper internals so I don't have a nuanced understanding of the decoding process.

}
}
}
Expand All @@ -4312,34 +4331,57 @@ static void whisper_process_logits(
}
}

#if 0
#if 1
// print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i);
const auto prob = probs[i];
const auto logit = logits[i];
const auto logprob = logprobs[i];
printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
//for (int i = 0; i < 10; i++) {
// const auto token = vocab.id_to_token.at(i);
// const auto prob = probs[i];
// const auto logit = logits[i];
// const auto logprob = logprobs[i];
// printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
//}

// print sorted
{
std::vector<std::pair<float, int>> pairs;

for (int i = 0; i < n_logits; ++i) {
pairs.push_back(std::make_pair(probs[i], i));
}

std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first;
});

for (int i = 0; i < 10; i++) {
const auto token = vocab.id_to_token.at(pairs[i].second);
const auto prob = pairs[i].first;
const auto logit = logits[pairs[i].second];
const auto logprob = logprobs[pairs[i].second];
printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
}

printf("----------------\n");
}

// "And", "and", " And", " and"
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);

printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);

printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
//printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
//printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
//printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
//printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
//printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);

//printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
//printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
//printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
//printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
//printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);

//printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
//printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
//printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
//printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
//printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}

Expand Down