
whisper : fine-tuning grammar functionality #1

Merged - 5 commits, merged into ejones:grammar on Sep 10, 2023

Conversation


@ggerganov ggerganov commented Sep 6, 2023

This is just a draft. If it works OK, we will merge it to ggerganov#1229

I believe I have improved the grammar behaviour - see the comments in the diff below.
At least it works quite robustly with the "red", "green", "blue" command example, both in English and Bulgarian.

I want to run some more sophisticated experiments and also make sure I didn't break anything else.
Let me know what you think about these changes - if you give this branch a try, you might want to disable the verbose prints in the #if 1 blocks.

@@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
    wparams.language  = params.language.c_str();
    wparams.n_threads = params.n_threads;

    // disable fallback - seems not useful for command recognition
    wparams.temperature_inc = 0.0f;
@ggerganov (Author) commented Sep 6, 2023

This seems to make the command recognition example much more robust, although we should enable it from time to time to make sure that multi-decoder grammar usage doesn't break.

@ggerganov (Author) commented Sep 9, 2023

It seems the reason this improves the results is that the fallback logic depends on the average logprob of the generated sequences:

https://github.com/ggerganov/whisper.cpp/blob/b8f34d1ed786194d9787b1f1d086b89136e361f3/whisper.cpp#L5166-L5171

Therefore, I think both the "best of" and "beam search" strategies would not work as expected.
We probably need some re-normalization of the logprobs after applying the grammar to make it compatible with them.

In any case, I think for now we should focus on greedy sampling without fallbacks and improve later if needed.
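The re-normalization mentioned above could look roughly like this - a minimal sketch, not code from this PR (the helper name and its placement are assumptions): after the grammar penalty is subtracted from the rejected tokens' logprobs, re-apply log-softmax so the values are proper log-probabilities again, which keeps the average-logprob comparison used by the fallback meaningful.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper (not part of whisper.cpp): renormalize penalized
// logprobs so that exp(logprobs) sums to 1 again.
static void renormalize_logprobs(std::vector<float> & logprobs) {
    // numerically stable log-sum-exp
    float max_lp = -INFINITY;
    for (float lp : logprobs) {
        max_lp = std::max(max_lp, lp);
    }
    double sum = 0.0;
    for (float lp : logprobs) {
        sum += std::exp((double)(lp - max_lp));
    }
    const float log_z = max_lp + (float) std::log(sum);
    // subtract the log partition function from every entry
    for (float & lp : logprobs) {
        lp -= log_z;
    }
}
```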

whisper.cpp Outdated
Comment on lines 3896 to 3898

for (const auto & reject : rejects) {
    if (logits[reject.id] > 0) {
        logits[reject.id] /= params.grammar_penalty;
    } else {
        logits[reject.id] *= params.grammar_penalty;
    }
    logprobs[reject.id] -= params.grammar_penalty;
@ggerganov (Author) commented Sep 6, 2023

I think this makes more sense, though I'm still experimenting.
At least for the use case where we want to strongly restrict the output to match the grammar, this works well with a large penalty (>100.0)

@ejones (Owner)

Yeah, this seems to work well enough - I've been testing this code with penalties of 10-15 on tiny. I was hoping for a sweet spot that would guide the decoding towards valid strings while leaving obviously non-matching strings unchanged. But so far, the penalty that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

@ggerganov (Author)

I suppose both variants fail in some way for different use cases. This penalty is essentially a free parameter that is not obvious how to set. A more general way would be to let the user provide a penalty callback function and let them implement the penalization, but I guess for now let's keep it simple.

But so far, I've found that the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

Does the original scaling work better in this scenario?
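The penalty-callback idea mentioned above could be sketched roughly like this - a hypothetical API, where the typedef and function names are illustrative and not part of whisper.cpp:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical callback type: the caller supplies the penalization logic
// for grammar-rejected tokens instead of a fixed scalar penalty.
typedef float (*whisper_grammar_penalty_fn)(int32_t token_id, float logprob, void * user_data);

// Illustrative helper: apply the user-provided penalty to every rejected token.
static void apply_grammar_penalty(
              std::vector<float>   & logprobs,
        const std::vector<int32_t> & rejects,  // token ids rejected by the grammar
        whisper_grammar_penalty_fn   fn,
        void * user_data) {
    for (int32_t id : rejects) {
        logprobs[id] = fn(id, logprobs[id], user_data);
    }
}
```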

@ejones (Owner)

Does the original scaling work better in this scenario?

Anecdotally, no, not really.

whisper.cpp Outdated
Comment on lines 3901 to 3907
// when the grammar does not allow any continuation, we don't want to penalize the EOT token
// TODO: is there are better way to do this?
printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
    logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
}
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
@ggerganov (Author) commented Sep 6, 2023

It seems we have an issue with EOT handling.
Here I penalize EOT if the grammar allows continuation, and this seems to help, although it's ugly.

@ejones (Owner)

There's actually the start of that, commented out, at the top of the function. I was just unsure initially if that penalty made sense in whisper land.

The decoded string is complete according to the grammar if some stack is empty, so the following will penalize EOT if the decoded string is an incomplete parse. The grammar might still allow continuation after a complete parse, so if we want to only penalize EOT if there's no continuation at all, I guess we'd switch the check to find a nonempty stack.

diff --git a/whisper.cpp b/whisper.cpp
index 5e3b86a..e796ddd 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar(
         return;
     }
 
-    // bool allow_eot = false;
-    // for (const auto & stack : grammar.stacks) {
-    //     if (stack.empty()) {
-    //         allow_eot = true;
-    //         break;
-    //     }
-    // }
+    bool allow_eot = false;
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            allow_eot = true;
+            break;
+        }
+    }
 
     std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
     std::vector<whisper_grammar_candidate>                              candidates_grammar;
@@ -3898,10 +3898,8 @@ static void whisper_suppress_invalid_grammar(
         logprobs[reject.id] -= params.grammar_penalty;
     }
 
-    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
-    // TODO: is there are better way to do this?
-    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
-    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+    // penalize EOT if grammar is incomplete
+    if (!allow_eot) {
         logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
     }
     //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());

Comment on lines +4188 to +4195
// suppress lang tokens
for (size_t i = 0; i < g_lang.size(); ++i) {
    logits[whisper_token_lang(&ctx, i)] = -INFINITY;
}

// suppress prev token
logits[vocab.token_prev] = -INFINITY;

@ggerganov (Author)

These are unrelated to the grammar functionality - I think they improve the transcription in general.
Not sure why OpenAI doesn't have these filters in the original implementation (or maybe they do?)

whisper.cpp Outdated
    logprobs[i] = -INFINITY;
}

whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
@ggerganov (Author)

Moving this down here improves the results in my tests significantly. The idea is to keep the logic for sampling a timestamp token intact, irrespective of what the grammar suggests. Only after we have decided that the next token is text do we apply the grammar filters.
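A toy sketch of this ordering (heavily simplified - the function, the token layout, and the dominance rule here are assumptions for illustration, not the actual decoder code):

```cpp
#include <cstddef>
#include <vector>

// Toy layout assumption: text tokens come first, timestamp tokens start at ts_begin.
// probs: per-token probabilities; rejects: grammar-rejected text token ids.
static int sample_token_sketch(
        std::vector<float> probs,
        size_t ts_begin,
        const std::vector<size_t> & rejects) {
    // 1) timestamp-vs-text decision on the *unmodified* distribution
    float p_ts = 0.0f, p_txt = 0.0f;
    for (size_t i = 0; i < probs.size(); ++i) {
        (i >= ts_begin ? p_ts : p_txt) += probs[i];
    }
    if (p_ts > p_txt) {
        // pick the most likely timestamp token; the grammar never sees it
        size_t best = ts_begin;
        for (size_t i = ts_begin; i < probs.size(); ++i) {
            if (probs[i] > probs[best]) best = i;
        }
        return (int) best;
    }
    // 2) only now suppress grammar-rejected text tokens, then pick greedily
    for (size_t id : rejects) probs[id] = 0.0f;
    size_t best = 0;
    for (size_t i = 1; i < ts_begin; ++i) {
        if (probs[i] > probs[best]) best = i;
    }
    return (int) best;
}
```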

@ejones (Owner)

Makes sense! This was my first foray into whisper internals so I don't have a nuanced understanding of the decoding process.

@ejones (Owner) commented Sep 7, 2023

Looking pretty good! I've been testing this draft with the chess grammar from my demo and a more complex home assistant grammar (the latter courtesy of GPT), using tiny-en and --grammar-penalty 15

Chess grammar

root ::= " Ok Whisper, start listening for commands. " ("Bishop " | "Rook " | "Knight " | "Queen " | "King " | "Pawn ") "to " [a-h] [1-8] (" and promotes to " ("bishop" | "rook" | "knight" | "queen" | "king"))? "."
Home assistant grammar
# - "turn on lights."
# - "set thermostat to 22."
# - "increase TV by 10."
# - "decrease oven by 50."
# - "play music."
# - "stop podcast."
# - "schedule cleaning at 3pm."
# - "cancel cleaning."
# - "remind me to buy milk at 5pm."
# - "show me security system."
# - "hide washing machine."
# - "what is the lights status?"
# - "what is the current thermostat value?"
# - "what is the security system status?"
# - "what is the door lock status?"
# - "what is the camera battery level?"
# - "what is the weather like today?"
# - "what is the forecast for tomorrow?"
# - "what is the time?"
# - "what is my schedule for today?"
# - "what tasks do I have?"
# - "what reminders do I have?"

root ::= " Ok Whisper, start listening for commands. " (command | question) "."

command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
            "Increase " device " by " value | "Decrease " device " by " value |
            "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
            "Remind me to " task " at " time | "Show me " device | "Hide " device

question ::= "What is the " device " status?" | "What is the current " device " value?" |
             "What is the " device " temperature?" | "What is the " device " humidity?" |
             "What is the " device " power consumption?" | "What is the " device " battery level?" |
             "What is the weather like today?" | "What is the forecast for tomorrow?" |
             "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
             "What reminders do I have?"

device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
           "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
           "vacuum cleaner"

value ::= [0-9]+

media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"

task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?

time ::= [0-9] [0-9] ":" [0-9] [0-9] ("am" | "pm")?

- option to read grammar from file
- add sample grammars for colors and chess moves
- fine-tune the performance further
@ggerganov (Author)

I've been testing with the following 3 commands on this branch and I think the results are really good:

./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue"

# say stuff like "d4 d5 knight to c3"
./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6"

./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands" 

The one thing that is missing is a way to filter "false positives" - i.e. when you say something completely outside the grammar, to recognize that it is unrelated.

@ejones (Owner) commented Sep 10, 2023

The one thing that is missing is to somehow filter "false positives". I.e. when you say something completely outside the grammar, to be able to understand that it is not related.

Yeah, that was my struggle exactly. An early thought I had, which I haven't tried yet, was to abandon grammar sampling if all the tokens suggested by the grammar seem sufficiently unlikely. Perhaps some logit threshold that a token would have to satisfy, or top-k/top-p rank?
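The logit-threshold idea could be sketched like so - hypothetical and untried in the PR, with the function name and the margin parameter purely illustrative:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// If even the best grammar-allowed token falls too far below the overall best
// token, report that the grammar seems implausible, so the caller could skip
// grammar suppression for this step. `margin` is a free parameter (in nats).
static bool grammar_seems_plausible(
        const std::vector<float> & logprobs,
        const std::vector<int>   & allowed,  // token ids the grammar accepts
        float margin) {
    const float best_any = *std::max_element(logprobs.begin(), logprobs.end());
    float best_allowed = -INFINITY;
    for (int id : allowed) {
        best_allowed = std::max(best_allowed, logprobs[id]);
    }
    return best_allowed >= best_any - margin;
}
```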

@ggerganov ggerganov marked this pull request as ready for review September 10, 2023 10:52
@ggerganov (Author)

I did a few more updates to the implementation and using the provided tests I get satisfying results. It's not yet "magical" as I imagined it could be (at least with tiny that is), but maybe we will get there after some more iterations :)

  • The allow_eot functionality that I enabled earlier is actually not useful, so I disabled it again
  • Added params.no_timestamps flag to disable sampling of timestamp tokens. This is supported by OpenAI Whisper, but we didn't have it in whisper.cpp yet. I think it helps in the use case of voice commands and grammar sampling, so it is enabled in the command example
  • Added --context parameter to the command example. With this, we can provide context - i.e. previous transcription that the model will see while transcribing the voice command. This context helps Whisper narrow down the transcription options. For example, if the context contains lower-case text, then Whisper is likely to continue transcribing in lower-case.
  • Fixed a bug in Top-K sampling when using Beam Search - instead of sampling, we were straight-up picking the top K tokens
  • command now uses Beam Search by default with a temperature of 0.4
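For reference, top-K sampling - drawing a token proportionally to its probability within the K best, as opposed to deterministically taking the K best - can be sketched like this (a simplified illustration, not the actual whisper.cpp beam-search code):

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Draw one token id from the top-k most probable entries, weighted by
// their probabilities.
static int sample_top_k(const std::vector<float> & probs, int k, std::mt19937 & rng) {
    std::vector<int> idx(probs.size());
    for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int) i;
    // move the K most probable token ids to the front
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&probs](int a, int b) { return probs[a] > probs[b]; });
    // draw among them proportionally to their probabilities
    std::vector<float> w(k);
    for (int i = 0; i < k; ++i) w[i] = probs[idx[i]];
    std::discrete_distribution<int> dist(w.begin(), w.end());
    return idx[dist(rng)];
}
```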

After these changes, I am using the following commands:

# chess, use big penalty to follow more strictly the grammar
./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100

# assistant, smaller penalty to allow some improvisations
./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10

Currently, if I say something completely outside the grammar, it usually transcribes an empty string, which is good.

The performance is a bit worse due to the Beam Search, but I think we will improve it from master.
There is also a sample probability estimate for the recognized voice command, but it does not work very well yet.

Let me know if you have any concerns with these changes or if your tests no longer work as expected.
If it's all good, let's merge this and I will proceed with moving the grammar parser into the library and then merge into master.

@ejones (Owner) commented Sep 10, 2023

That all sounds good. I'll merge now and will be able to test this evening.

@ejones ejones merged commit de7021f into ejones:grammar Sep 10, 2023