Tf timestamps whisper + update generate support #21334
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
A few suggestions 🙌
Awesome, thanks for the review 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
lmk when you want to pick this up again :P Meanwhile, shall we add the WIP label, so that the bot doesn't ping us?
Yes! Hahah sorry, maybe next week or 2 weeks from now!
Co-authored-by: Joao Gante <joao@huggingface.co>
Okay! Thanks to @gante's recommendations, the XLA generation works perfectly! The slow timestamp processing test also passes 🥳
Looks great! 🔥
I think we still have the Beam Search XLA Whisper problem -- I'm going to prioritize it now so we can announce this cool feature soon!
Thanks for adding this! Super exciting to have this in the other frameworks ⭐
Just a few small comments asking for checks on the inputs.
```diff
@@ -1000,7 +1000,7 @@ def generate(
         doc_scores=None,
         n_docs=None,
         generation_config=None,
-        logits_processor=TFLogitsProcessorList(),
+        logits_processor: Optional[TFLogitsProcessorList] = TFLogitsProcessorList(),
```
Is it optional in the sense that it can be a `None` value here?
Yes! This follows the same pattern as with torch, where you usually don't pass a logits processor.
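(For illustration, a minimal sketch of the fallback pattern being described; the helper name is hypothetical and not part of the PR:)

```python
from transformers.generation import TFLogitsProcessorList

def resolve_logits_processor(logits_processor=None):
    # Hypothetical helper: treat a None processor the same as "no extra
    # processors", mirroring the torch-side default behaviour.
    return logits_processor if logits_processor is not None else TFLogitsProcessorList()
```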
Code under review:

```python
forced_decoder_ids = []
if task is not None or language is not None:
    if hasattr(generation_config, "language"):
        if generation_config.language in generation_config.lang_to_id.keys():
            language_token = generation_config.language
        elif generation_config.language in TO_LANGUAGE_CODE.keys():
            language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
        else:
            raise ValueError(
                f"Unsupported language: {generation_config.language}. Language should be one of:"
                f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
            )
        forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
    else:
        forced_decoder_ids.append((1, None))  # automatically detect the language

    if hasattr(generation_config, "task"):
        if generation_config.task in TASK_IDS:
            forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
        else:
            raise ValueError(
                f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
            )
    else:
        forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe

    if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
        idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
        forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
    forced_decoder_ids = self.config.forced_decoder_ids
elif (
    hasattr(self.generation_config, "forced_decoder_ids")
    and self.generation_config.forced_decoder_ids is not None
):
    forced_decoder_ids = self.generation_config.forced_decoder_ids
```
I would add an extra check on the inputs to prevent this function having silent, unexpected behaviour, specifically `forced_decoder_ids` from the config overriding the `language` or `task` arguments:
Suggested change (the block above, restructured so a legacy `forced_decoder_ids` conflicts loudly with explicit `language`/`task` arguments):

```python
forced_decoder_ids = []
legacy_forced_decoder_ids = []
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
    legacy_forced_decoder_ids = self.config.forced_decoder_ids
elif (
    hasattr(self.generation_config, "forced_decoder_ids")
    and self.generation_config.forced_decoder_ids is not None
):
    legacy_forced_decoder_ids = self.generation_config.forced_decoder_ids

if task is not None or language is not None:
    if legacy_forced_decoder_ids:
        raise ValueError(
            "Cannot specify language or task if forced_decoder_ids in model config or generation_config is set. "
            "Please remove forced_decoder_ids from config file(s)"
        )
    if hasattr(generation_config, "language"):
        if generation_config.language in generation_config.lang_to_id.keys():
            language_token = generation_config.language
        elif generation_config.language in TO_LANGUAGE_CODE.keys():
            language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
        else:
            raise ValueError(
                f"Unsupported language: {generation_config.language}. Language should be one of:"
                f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
            )
        forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
    else:
        forced_decoder_ids.append((1, None))  # automatically detect the language

    if hasattr(generation_config, "task"):
        if generation_config.task in TASK_IDS:
            forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
        else:
            raise ValueError(
                f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
            )
    else:
        forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe

    if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
        idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
        forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

# Legacy code for backward compatibility
elif legacy_forced_decoder_ids:
    forced_decoder_ids = legacy_forced_decoder_ids
```
For backward compatibility reasons, we had to make sure that the `config`'s `forced_decoder_ids` still take priority. TF would probably also suffer from this 😞
Code under review:

```python
self.timestamp_begin = generate_config.no_timestamps_token_id + 1

self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
```
Is `forced_decoder_ids` always going to be set?
It has to! In `generate` it will always be given a default value, but we can raise an error for easier debugging.
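(A hedged sketch of such a check; the exact message and placement are assumed, not the PR's code:)

```python
from transformers.generation import TFLogitsProcessor

class TFWhisperTimeStampLogitsProcessorSketch(TFLogitsProcessor):
    def __init__(self, generate_config):
        # Assumed defensive check: fail fast with a clear message if
        # generate() did not populate forced_decoder_ids first.
        if getattr(generate_config, "forced_decoder_ids", None) is None:
            raise ValueError(
                "`forced_decoder_ids` must be set on the generation config "
                "before using the timestamp logits processor."
            )
        self.begin_index = len(generate_config.forced_decoder_ids) + 2
```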
Code under review:

```python
forced_decoder_ids = []

if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
    if task is not None or language is not None:
```
Same notes as in `modeling_tf_whisper` about checking the inputs.
Code under review:

```python
    new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
    return new_scores

def _force_tokens(current_tokens):
```
nit: My understanding is that this function only works with a single token (not tokens).
Suggested change:

```diff
-def _force_tokens(current_tokens):
+def _force_token(current_token):
```
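(To make the rename concrete, a standalone sketch of forcing a single token in TF by masking every other logit; the helper signature is assumed, not the PR's exact code:)

```python
import tensorflow as tf

def _force_token(scores, token_id):
    # Set every logit to -inf except the forced token, which gets 0,
    # so sampling/argmax can only pick `token_id`.
    batch_size = tf.shape(scores)[0]
    vocab_size = tf.shape(scores)[1]
    new_scores = tf.fill([batch_size, vocab_size], float("-inf"))
    indices = tf.stack([tf.range(batch_size), tf.fill([batch_size], token_id)], axis=1)
    return tf.tensor_scatter_nd_update(new_scores, indices, tf.zeros([batch_size]))
```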
Code under review:

```python
if generation_config.return_timestamps:
    if logits_processor is not None:
        logits_processor += [TFWhisperTimeStampLogitsProcessor(generation_config)]
    else:
        logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)]
```
Should there be a default `logits_processor` set here if `generation_config.return_timestamps` is `False`?
No, if we don't return timestamps we just don't need an additional processor (but we should make sure that the `forced_decoder_ids` has the `no_timestamp_token` forced).
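(A small illustration of that branch; the token ids below are hypothetical placeholders:)

```python
# Hypothetical token ids, for illustration only.
lang_token_id, transcribe_token_id, no_timestamps_token_id = 50259, 50359, 50363
return_timestamps = False

forced_decoder_ids = [(1, lang_token_id), (2, transcribe_token_id)]
if not return_timestamps:
    # Force the no-timestamps token right after the task token instead of
    # adding a timestamp logits processor.
    idx = forced_decoder_ids[-1][0] + 1
    forced_decoder_ids.append((idx, no_timestamps_token_id))

print(forced_decoder_ids)  # [(1, 50259), (2, 50359), (3, 50363)]
```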
Code under review:

```python
class TFWhisperTimeStampLogitsProcessor(TFLogitsProcessor):
    r"""
    Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
```
Could you expand this a little to specify what's special about it being a timestamp logits processor?
Sure, will also update the PyTorch definition to make it clearer!
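(One possible expansion, sketched here with assumed wording rather than the docstring that was eventually merged:)

```python
from transformers.generation import TFLogitsProcessor

class TFWhisperTimeStampLogitsProcessorDoc(TFLogitsProcessor):
    r"""
    Whisper-specific processor that enforces the timestamp grammar Whisper
    was trained with: generation starts with a timestamp, timestamp tokens
    delimit text segments in pairs with non-decreasing values, and when the
    total probability mass on timestamp tokens exceeds that of any single
    text token, a timestamp is forced.
    """
```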
Code under review:

```python
        return scores

    # len(seq) == 1 corresponds to cur_len == self.begin_index + 2
    last_was_timestamp = (cur_len >= self.begin_index + 2) & (input_ids[:, cur_len - 1] >= self.timestamp_begin)
    penultimate_was_timestamp = (cur_len < self.begin_index + 3) | (
```
Just for my understanding: is the `cur_len < self.begin_index + 3` line there because the first timestamp won't have a previous timestamp token prediction?
Yes! I can add a comment for this 😉
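(For concreteness, a runnable toy summary of the masking rule, distilled from the torch counterpart; the names are illustrative, not the PR's:)

```python
def allowed_next(last_was_timestamp: bool, penultimate_was_timestamp: bool) -> str:
    # A closed timestamp pair forces plain text next; an open pair forces
    # another timestamp (or EOS). Positions with no penultimate prediction
    # yet (cur_len < begin_index + 3) are treated as "penultimate was a
    # timestamp", which is exactly the case asked about above.
    if last_was_timestamp and penultimate_was_timestamp:
        return "text only"
    if last_was_timestamp:
        return "timestamp or EOS only"
    return "anything"

print(allowed_next(True, False))  # timestamp or EOS only
```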
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Thanks for your review, will address all of this.
@ArthurZucker I was testing out if I get the timestamps with the TF model with your …
Hey! That’s probably because I haven’t pulled from main for a while and we changed the whisper tokenizer. As you can see, the decoding process is the one failing here.
@ArthurZucker Thanks for the response. I got the issue resolved with …, i.e. changing …
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
This PR updates the TF and FLAX generation code to fix the breaking changes we had. It also adds support for timestamps in TF. Follows #21965
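A hedged usage sketch of what this enables (the checkpoint id and decode kwargs are assumptions, not taken from the PR):

```python
import numpy as np
from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy one-second clip at 16 kHz; replace with real audio samples.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="tf")

generated = model.generate(inputs.input_features, return_timestamps=True)
print(processor.batch_decode(generated, decode_with_timestamps=True))
```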