
[WIP] [Whisper] Fix generate and tokenizer behavior with added tokens #33512

Open

eustlb wants to merge 11 commits into main from fix-whisper-generate-added-tokens

Conversation

@eustlb (Contributor) commented Sep 16, 2024

What does this PR do?

Fixes #33082.

Description of the issue

Transformers' PreTrainedTokenizer._add_tokens adds tokens at the end of the vocabulary. Likewise, PreTrainedModel's _get_resized_embeddings appends newly initialized tokens at the end. This behavior is incompatible with how Whisper identifies timestamp tokens: both the official implementation and the Transformers one treat any token with an id > timestamp_begin as a timestamp. As a result, newly added tokens are falsely considered timestamps, which breaks decoding.
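A minimal sketch of the failure mode, assuming a checkpoint whose tokenizer includes the timestamp tokens (e.g. openai/whisper-tiny); exact ids vary by checkpoint:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
timestamp_begin = tokenizer.convert_tokens_to_ids("<|0.00|>")  # first timestamp token

# add_tokens appends at the very end of the vocabulary, i.e. after the
# timestamp tokens.
tokenizer.add_tokens(["<my_token>"])
new_id = tokenizer.convert_tokens_to_ids("<my_token>")

# Any decoding rule of the form `id > timestamp_begin` now misclassifies
# the new text token as a timestamp.
print(new_id > timestamp_begin)  # True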

Implementation possibilities

Two possibilities here:

  1. change Whisper's decoding logic using a timestamp_end
  2. overwrite _add_tokens and _get_resized_embeddings so that new tokens are added before the first timestamp token (rather than at the end, as the current implementation does)

I chose option 1 to avoid overwriting the default methods of the Transformers library, focusing instead on modifying the Whisper-specific generation method and tokenizer logic.
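As a sketch, the option-1 check could look as follows (names follow this discussion and are not necessarily those of the final implementation):

def is_timestamp(token_id: int, timestamp_begin: int, timestamp_end: int) -> bool:
    # Old rule: token_id >= timestamp_begin, which breaks once new tokens
    # are appended after the timestamp range.
    # New rule: bound the range on both sides; timestamp_end here denotes
    # the id of the last timestamp token.
    return timestamp_begin <= token_id <= timestamp_end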

Implementation decision (see the discussion below for details)

In the updated implementation, Whisper will now use both timestamp_begin and timestamp_end, instead of just timestamp_begin as in OpenAI’s original version.


Who can review?

@ylacombe

TODO

  • discuss different implementations and choose
  • add tests before merging

@eustlb eustlb changed the title Fix whisper generate added tokens [Whisper] Fix generate and tokenizer behavior with added tokens Sep 16, 2024
@ylacombe (Contributor) left a comment

Thanks for this PR @eustlb, it's already in great shape!

I left a few comments here and there, but the most important ones are on the tokenizer side.

What you proposed looks great to me; the only downside I see is that users would have to be careful about how they add new tokens, especially if they want to introduce new timestamp tokens:

  1. first they'd have to add timestamp tokens to the tokenizer and add number_timestamp_tokens to the generation_config
  2. then they can add new tokens as usual

I don't see any clear way to avoid this downside, and I don't think this will be a frequent use-case anyway, so I'm happy to keep your current way of doing it.
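For illustration, that two-step flow might look like this (a sketch: number_timestamp_tokens follows this PR and is not part of a released API, and the added token strings are made up):

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 1. Add the extra timestamp tokens first, then record the new total count
#    in the generation config.
extra_timestamps = [f"<|{30.02 + 0.02 * i:.2f}|>" for i in range(50)]
processor.tokenizer.add_tokens(extra_timestamps)
model.generation_config.number_timestamp_tokens = 1501 + len(extra_timestamps)

# 2. Only then add ordinary text tokens as usual and resize the embeddings.
processor.tokenizer.add_tokens(["<my_token>"])
model.resize_token_embeddings(len(processor.tokenizer))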


Besides the comments, it'd be great to add some tests to test_modeling_whisper.py and test_tokenization_whisper.py to make sure generation and tokenization work when you add new tokens!

@@ -347,7 +347,9 @@ def generate(
     synced_gpus (`bool`, *optional*, defaults to `False`):
         Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
     return_timestamps (`bool`, *optional*):
-        Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
+        Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
+        By default, Whisper uses 1501 timestamp tokens. If a custom number of timestamp tokens is needed,
Contributor:

This is nice! Should we also write this somewhere in whisper.md to further highlight it since it's quite hidden from the user?

Comment on lines 332 to 335
@property
def timestamp_end(self) -> int:
timestamp_ids = [value for key, value in self.added_tokens_encoder.items() if self.timestamp_pat.match(key)]
return max(timestamp_ids)
Contributor:

It's on this point that I'll be the most cautious, as we introduced two ways of computing timestamp_end:

  1. In the modeling file, with timestamp_begin + generation_config.get("number_timestamp_tokens", 1501).
  2. In the tokenizer, with the current approach: taking the maximum id of all tokens that match the timestamp pattern.

@itazap and @ArthurZucker, WDYT of this?

To give a bit of context, the Whisper model has its usual vocabulary, to which a series of timestamp token ids (usually 1501 tokens) is appended. When a user appends a new token to the tokenizer, it is wrongly identified as a timestamp token since its id is greater than the original vocabulary size.
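For concreteness, the two computations look roughly like this (a sketch; attribute names follow the PR discussion):

# 1. Generation side: derived from the config, with a hardcoded default.
timestamp_end = timestamp_begin + getattr(generation_config, "number_timestamp_tokens", 1501)

# 2. Tokenizer side: the largest added-token id whose surface form matches
#    the timestamp pattern, e.g. <|12.34|>.
timestamp_end = max(
    token_id
    for token, token_id in tokenizer.added_tokens_encoder.items()
    if tokenizer.timestamp_pat.match(token)
)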

Contributor:

Would it make sense to avoid using timestamp_begin and timestamp_end here, and to simply verify, in the rest of the file, whether a token is a timestamp via self.timestamp_pat.match(key)?

Contributor:

This could be a way to be more rigorous about how we treat timestamps in the tokenizer, but it would introduce a difference between how we compute timestamps in the generation file and here.

@eustlb (Contributor Author):

Following up on this! I agree that the introduced inconsistency is not convenient. Since I do not see another way of identifying timestamp_end, and since the solution I proposed is computationally suboptimal (it is recomputed at each call, which could be improved with an LRU cache), I agree that identifying timestamp tokens on the tokenizer's path would be better done with a regex match. This would remove the need for timestamp_end; nevertheless, we can't avoid using timestamp_begin, since it is required to compute timestamp values (here).
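A minimal sketch of the cached, regex-based identification discussed here (the helper name is hypothetical; the pattern mirrors the tokenizer's timestamp_pat):

import re
from functools import lru_cache

TIMESTAMP_PAT = re.compile(r"<\|(\d+\.\d+)\|>")

@lru_cache(maxsize=None)
def is_timestamp_token(token: str) -> bool:
    # Identify timestamps by surface form (e.g. <|12.34|>) rather than by id,
    # so newly added text tokens are never misclassified; caching avoids
    # re-running the regex for tokens seen before.
    return TIMESTAMP_PAT.match(token) is not None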

Collaborator:

Overall, let's aim for simplicity:

  • what is the best for the users (least amount of work for them)
  • what is the simplest code for this.

@eustlb eustlb force-pushed the fix-whisper-generate-added-tokens branch from 9b88b2c to e78bbed on September 23, 2024 at 15:25
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@eustlb (Contributor Author) commented Sep 24, 2024

→ Let's wrap up a bit and justify the choices further:

In the updated implementation, Whisper will now use both timestamp_begin and timestamp_end, instead of just timestamp_begin as in OpenAI’s original version.

→ Justification:

what is the best for the users (least amount of work for them)

We have two scenarios:

  1. Adding text tokens.
  2. Adding timestamp tokens.

These cases overlap, as both types of tokens must be added via the add_tokens method. For the system to work seamlessly:

  • Text tokens should be inserted before the first timestamp token (Case 1).
  • Timestamp tokens should be added at the end of the vocabulary (Case 2).

To handle Case 1, modifying the add_tokens method would also require changes to resize_token_embeddings. From a simplicity perspective, deviating from the standard implementation of these methods complicates things unnecessarily.

Since Case 1 is more common, we should prioritize simplifying it for the user. Hardcoding a default number of timestamp tokens achieves this, and the number_timestamp_tokens config option easily accommodates Case 2 for the few users who need it.

As for the tokenizer, the regex pattern currently used to detect timestamps prevents the use of custom timestamp tokens (that is, tokens with a custom syntax that does not match the pattern). However, since this is a rare use case and the current implementation already depends on regex, it's best to keep things simple and do it this way.
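To illustrate the limitation (the pattern shown mirrors the one in tokenization_whisper.py):

import re

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
print(bool(timestamp_pat.match("<|12.34|>")))   # True: standard syntax is detected
print(bool(timestamp_pat.match("<ts_12.34>")))  # False: custom syntax would be missed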

@eustlb eustlb changed the title [Whisper] Fix generate and tokenizer behavior with added tokens [WIP] [Whisper] Fix generate and tokenizer behavior with added tokens Oct 13, 2024
@eustlb eustlb mentioned this pull request Oct 14, 2024
Successfully merging this pull request may close these issues:

Whisper generate return a slice of result if result have more than one added token