[WIP] [Whisper] Fix generate and tokenizer behavior with added tokens #33512
base: main
Conversation
Thanks for this PR @eustlb, it's already in great shape!

I left a few comments here and there, but the most important ones are on the tokenizer side.

What you proposed looks great to me. The only downside I see is that users would have to be careful about how they add new tokens, especially if they want to introduce new timestamp tokens:
- first they'd have to add the timestamp tokens to the tokenizer and set `number_timestamp_tokens` in the `generation_config`
- then they can add new tokens as usual

I don't see any clear way to avoid this downside, and I don't think this will be a frequent use case anyway, so I'm happy to keep your current way of doing it.

Besides the comments, it'd be great to add some tests to `test_modeling_whisper.py` and to `test_tokenization_whisper.py` to make sure generation and tokenization work when you add new tokens!
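A minimal sketch of what such a tokenizer test could look like, assuming the fix is in place (the test name and the added token string are illustrative, not the PR's actual tests):

```python
# Hypothetical addition to test_tokenization_whisper.py: once added tokens are
# no longer mistaken for timestamps, they should survive a decode round-trip.
def test_decode_with_added_tokens(self):
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    tokenizer.add_tokens(["<my_new_token>"])
    new_id = tokenizer.convert_tokens_to_ids("<my_new_token>")
    # The added token must come back as itself, not be dropped as a timestamp.
    self.assertIn("<my_new_token>", tokenizer.decode([new_id]))
```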
```diff
@@ -347,7 +347,9 @@ def generate(
     synced_gpus (`bool`, *optional*, defaults to `False`):
         Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
     return_timestamps (`bool`, *optional*):
-        Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
+        Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
+        By default, Whisper uses 1501 timestamp tokens. If a custom number of timestamp tokens is needed,
```
This is nice! Should we also write this somewhere in `whisper.md` to further highlight it, since it's quite hidden from the user?
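If I read the hunk right, a user who wants more than the default 1501 timestamps would first add the tokens, then record the new count in the generation config. A hedged sketch (`number_timestamp_tokens` is the field proposed in this PR; the extra token strings are illustrative):

```python
from transformers import WhisperForConditionalGeneration, WhisperTokenizer

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Extend the default <|0.00|> ... <|30.00|> grid (1501 tokens, 0.02 s steps)
# up to 31 s, then declare the new total to generate().
extra_timestamps = [f"<|{t / 100:.2f}|>" for t in range(3002, 3102, 2)]
tokenizer.add_tokens(extra_timestamps)
model.resize_token_embeddings(len(tokenizer))
model.generation_config.number_timestamp_tokens = 1501 + len(extra_timestamps)
```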
```python
@property
def timestamp_end(self) -> int:
    # Timestamp tokens live in added_tokens_encoder; the largest id among
    # the tokens matching the timestamp pattern marks the end of the range.
    timestamp_ids = [value for key, value in self.added_tokens_encoder.items() if self.timestamp_pat.match(key)]
    return max(timestamp_ids)
```
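For illustration, with the stock vocabulary (no custom timestamps added), this property would resolve to the id of `<|30.00|>`, the last default timestamp token:

```python
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# Highest id among added tokens matching the timestamp pattern.
assert tokenizer.timestamp_end == tokenizer.convert_tokens_to_ids("<|30.00|>")
```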
It's on this point that I'll be the most cautious, as we introduced two ways of computing `timestamp_end`:
- in the modeling file, with `timestamp_begin + generation_config.get("number_timestamp_tokens", 1501)`
- in the tokenizer, with the current way of doing it: taking the maximum id of all tokens that match the timestamp pattern.

@itazap and @ArthurZucker, WDYT of this?
To give a bit of context, the Whisper model has its usual vocabulary, to which a series of timestamp token ids is appended (usually 1501 tokens). When a user appends a new token to the tokenizer, it is wrongly identified as a timestamp token since its id is greater than the original vocabulary size.
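To make the failure mode concrete, a rough repro sketch (the `>= timestamp_begin` comparison mirrors the id-based heuristic described above; exact ids vary per checkpoint):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
timestamp_begin = tokenizer.convert_tokens_to_ids("<|0.00|>")

# A freshly added token is appended at the very end of the vocabulary...
tokenizer.add_tokens(["<my_new_token>"])
new_id = tokenizer.convert_tokens_to_ids("<my_new_token>")

# ...so an id-based check mistakes it for a timestamp token.
print(new_id >= timestamp_begin)  # True, although it is not a timestamp
```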
Would it make sense to avoid using timestamp begin and timestamp end here, and to simply verify in the rest of the file whether a token is a timestamp thanks to `self.timestamp_pat.match(key)`?
This could be a way to be more rigorous about how we treat timestamps in the tokenizer, but it would introduce a difference in how we compute timestamps in the generation file and here.
Following up on this! I do agree that the introduced inconsistency is not convenient. Since I do not see another way of identifying `timestamp_end`, and the solution I proposed is computationally suboptimal (recomputed at each call; could be improved with an LRU cache), I agree that identifying timestamp tokens on the tokenizer's path would be better done with a regex match. This would remove the need for `timestamp_end`; nevertheless, we can't avoid using `timestamp_begin`, since it is required to compute timestamp values (here).
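For reference, one possible shape for the cached, regex-based variant discussed here, as a mixin sketch (assuming it sits on the slow `WhisperTokenizer`, which exposes `added_tokens_encoder`; the pattern mirrors the `<|t.tt|>` timestamp syntax, and the cache would need invalidating if tokens are added afterwards):

```python
import re
from functools import cached_property

class TimestampLookupMixin:
    timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

    @cached_property
    def _timestamp_ids(self) -> frozenset:
        # Computed once per tokenizer instance instead of at every call.
        return frozenset(
            token_id
            for token, token_id in self.added_tokens_encoder.items()
            if self.timestamp_pat.match(token)
        )

    def is_timestamp(self, token_id: int) -> bool:
        # Membership test replaces the fragile "id > timestamp_begin" check.
        return token_id in self._timestamp_ids
```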
Overall, let's aim for simplicity:
- what is best for the users (least amount of work for them)
- what is the simplest code for this
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Let's wrap up a bit and justify the choices further: in the updated implementation, Whisper will now use both `timestamp_begin` and `timestamp_end`, instead of just `timestamp_begin` as in OpenAI's original version. The full justification is written up in the PR description below, under "Implementation decision".
What does this PR do?
Fixes #33082.
Description of the issue
Transformers' `PreTrainedTokenizer._add_tokens` adds tokens at the end of the vocabulary. Likewise, `PreTrainedModel`'s `_get_resized_embeddings` will add newly initialized tokens at the end. This behavior is not compatible with Whisper's timestamp token identification: the official implementation, as well as Transformers' one, identifies timestamp tokens as the ones that have an id > `timestamp_begin`. For this reason, newly added tokens are falsely considered timestamps, and this causes decoding issues.

Implementation possibilities
Two possibilities here:
1. Modify the Whisper-specific generation method and tokenizer logic so that timestamp tokens are correctly identified even after new tokens are added.
2. Overwrite `_add_tokens` and `_get_resized_embeddings` in order to add new tokens before the first timestamp token (and not at the end as in the current implementation).

I chose option 1 to avoid overwriting the default methods of the Transformers library, focusing instead on modifying the Whisper-specific generation method and tokenizer logic.
Implementation decision (see the discussion above for context)
In the updated implementation, Whisper will now use both `timestamp_begin` and `timestamp_end`, instead of just `timestamp_begin` as in OpenAI's original version.

Justification: what is best for the users (least amount of work for them)?

We have two scenarios:
- Case 1: the user adds regular new tokens to the tokenizer (the common case).
- Case 2: the user adds new timestamp tokens.

These cases overlap, as both types of tokens must be added via the `add_tokens` method, and the system has to work seamlessly for both.

To handle Case 1, modifying the `add_tokens` method would also require changes to `resize_token_embeddings`. From a simplicity perspective, deviating from the standard implementation of these methods complicates things unnecessarily.

Since Case 1 is more common, we should prioritize simplifying it for the user. Hardcoding a default number of timestamp tokens achieves this, and the `number_timestamp_tokens` config option easily accommodates Case 2 for the few users who need it.

As for the tokenizer, the regex pattern currently used to detect timestamps prevents the use of custom timestamp tokens (meaning tokens with a custom syntax that does not match the pattern). However, since this is a rare use case and the current implementation already depends on this regex, it's best to keep things simple and do it this way.
Who can review?
@ylacombe
TODO