[Whisper Tokenizer] Make decoding faster after adding timestamps #26299
What does this PR do?
Following the update to the Whisper tokenizer to handle encoding/decoding timestamps (#26054), there is one line in the decoding which takes extremely long:
`transformers/src/transformers/models/whisper/tokenization_whisper.py`, line 614 (at f94c9b3)
Here we do an O(N * M) operation to filter out all the timestamp tokens, where N is the length of the token ids and M is the number of timestamp tokens (for each token, we check whether it's in the timestamp token list). In practice, this causes decoding to take extremely long for typical validation sets, e.g. LibriSpeech test-clean took ~30 minutes for the tokenizer to decode on a TPU v3 (which has plenty of CPU power to run this operation).
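For illustration, a minimal sketch of the slow pattern described above (names and token ids are illustrative, not the actual transformers code): every token id is tested for membership in a Python list of timestamp ids, so the cost is O(N * M).

```python
# Hypothetical sketch of the per-token membership check (assumed ids).
timestamp_begin = 50364  # assumed id of <|0.00|>; the real value depends on the vocab
timestamp_ids = [timestamp_begin + i for i in range(1501)]  # all <|X.XX|> token ids


def filter_timestamps_slow(token_ids):
    # List membership is O(M) per token, so this loop is O(N * M) overall.
    return [t for t in token_ids if t not in timestamp_ids]
```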
This PR switches the timestamp filtering to a regex string operation, which in a toy benchmark was over 2000x faster. Would love to hear from @ArthurZucker whether we're happy to sacrifice a bit of readability for this speed-up!
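As a rough sketch of the regex-based alternative (not the exact implementation in this PR): decode to a string first, then strip the `<|X.XX|>` style timestamp markers in a single pass over the text.

```python
import re

# Matches timestamp markers of the form "<|12.34|>" (assumed format).
_TIMESTAMP_RE = re.compile(r"<\|\d+\.\d+\|>")


def strip_timestamps_fast(text: str) -> str:
    # One linear scan over the decoded string instead of a per-token list lookup.
    return _TIMESTAMP_RE.sub("", text)


# Example usage:
# strip_timestamps_fast("<|0.00|> Hello world.<|2.40|>") -> " Hello world."
```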