
Tf timestamps whisper + update generate support #21334

Closed

Conversation

ArthurZucker
Collaborator

@ArthurZucker ArthurZucker commented Jan 27, 2023

What does this PR do?

This PR updates the way we generate in TF and FLAX to fix the breaking changes we had.
It also adds support for the timestamps in TF.
Follows #21965
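
For reference, a minimal sketch of how the feature would be used once merged. This is an illustration, not code from the PR: the checkpoint choice and the dummy audio are assumptions, and return_timestamps is the new TF-side flag this PR adds.

import numpy as np
from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

# hypothetical checkpoint choice; any Whisper checkpoint with TF weights should do
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in
inputs = processor(audio, sampling_rate=16000, return_tensors="tf")

# return_timestamps=True enables the timestamp logits processor on the TF side
generated_ids = model.generate(inputs.input_features, return_timestamps=True)
text = processor.tokenizer.decode(generated_ids[0].numpy(), decode_with_timestamps=True)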

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker ArthurZucker changed the title Tf timestamps whisper [Whisper] Tf timestamps whisper Jan 27, 2023
@ArthurZucker ArthurZucker requested a review from gante January 30, 2023 13:09
Member

@gante gante left a comment


A few suggestions 🙌

@ArthurZucker
Collaborator Author

Awesome thanks for the review 🤗

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gante
Member

gante commented Feb 27, 2023

lmk when you want to pick this up again :P Meanwhile, shall we add the WIP label, so that the bot doesn't ping us?

@ArthurZucker
Collaborator Author

Yes! Hahah sorry, maybe next week or 2 weeks from now!

@ArthurZucker ArthurZucker changed the title [Whisper] Tf timestamps whisper [WIP] Tf timestamps whisper Feb 27, 2023
@ArthurZucker ArthurZucker changed the title [WIP] Tf timestamps whisper [WIP] Tf timestamps whisper + update generate support Mar 14, 2023
Co-authored-by: Joao Gante <joao@huggingface.co>
@ArthurZucker
Collaborator Author

ArthurZucker commented Mar 14, 2023

Okay! Thanks to @gante's recommendations, the XLA generation works perfectly! The slow timestamp processing test also passes 🥳

@ArthurZucker ArthurZucker marked this pull request as ready for review March 14, 2023 10:17
@ArthurZucker ArthurZucker changed the title [WIP] Tf timestamps whisper + update generate support Tf timestamps whisper + update generate support Mar 14, 2023
Member

@gante gante left a comment


Looks great! 🔥

I think we still have the Beam Search XLA Whisper problem -- I'm going to prioritize it now so we can announce this cool feature soon!

@ArthurZucker ArthurZucker requested a review from sgugger March 14, 2023 12:40
Collaborator

@amyeroberts amyeroberts left a comment


Thanks for adding this! Super exciting to have this in the other frameworks ⭐

Just a few small comments asking for checks on inputs.

@@ -1000,7 +1000,7 @@ def generate(
        doc_scores=None,
        n_docs=None,
        generation_config=None,
-       logits_processor=TFLogitsProcessorList(),
+       logits_processor: Optional[TFLogitsProcessorList] = TFLogitsProcessorList(),
Collaborator


Is it optional in the sense it can be a None value here?

Collaborator Author


Yes! This follows the same pattern as in torch, where you usually don't pass a logits processor.
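
To make that concrete, a small sketch of the calling pattern (not the PR's exact code; the None fallback shown is an assumption):

from typing import Optional

from transformers import TFLogitsProcessorList

def generate(logits_processor: Optional[TFLogitsProcessorList] = None, **kwargs):
    # callers typically pass nothing; None falls back to an empty processor list
    if logits_processor is None:
        logits_processor = TFLogitsProcessorList()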

Comment on lines +1374 to +1410
forced_decoder_ids = []
if task is not None or language is not None:
    if hasattr(generation_config, "language"):
        if generation_config.language in generation_config.lang_to_id.keys():
            language_token = generation_config.language
        elif generation_config.language in TO_LANGUAGE_CODE.keys():
            language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
        else:
            raise ValueError(
                f"Unsupported language: {self.language}. Language should be one of:"
                f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
            )
        forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
    else:
        forced_decoder_ids.append((1, None))  # automatically detect the language

    if hasattr(generation_config, "task"):
        if generation_config.task in TASK_IDS:
            forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
        else:
            raise ValueError(
                f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
            )
    else:
        forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
    if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
        idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
        forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

# Legacy code for backward compatibility
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
    forced_decoder_ids = self.config.forced_decoder_ids
elif (
    hasattr(self.generation_config, "forced_decoder_ids")
    and self.generation_config.forced_decoder_ids is not None
):
    forced_decoder_ids = self.generation_config.forced_decoder_ids
Collaborator


I would add an extra check on the inputs to prevent this function having silent, unexpected behaviour - specifically forced_decoder_ids from the config overriding the language or task arguments

Suggested change (the block quoted above, rewritten to validate the inputs first):

forced_decoder_ids = []

legacy_forced_decoder_ids = []
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
    legacy_forced_decoder_ids = self.config.forced_decoder_ids
elif (
    hasattr(self.generation_config, "forced_decoder_ids")
    and self.generation_config.forced_decoder_ids is not None
):
    legacy_forced_decoder_ids = self.generation_config.forced_decoder_ids

if task is not None or language is not None:
    if legacy_forced_decoder_ids:
        raise ValueError(
            "Cannot specify language or task if forced_decoder_ids in model config or generation_config is set. "
            "Please remove forced_decoder_ids from config file(s)"
        )
    if hasattr(generation_config, "language"):
        if generation_config.language in generation_config.lang_to_id.keys():
            language_token = generation_config.language
        elif generation_config.language in TO_LANGUAGE_CODE.keys():
            language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
        else:
            raise ValueError(
                f"Unsupported language: {self.language}. Language should be one of:"
                f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
            )
        forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
    else:
        forced_decoder_ids.append((1, None))  # automatically detect the language
    if hasattr(generation_config, "task"):
        if generation_config.task in TASK_IDS:
            forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
        else:
            raise ValueError(
                f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
            )
    else:
        forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
    if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
        idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
        forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
# Legacy code for backward compatibility
elif legacy_forced_decoder_ids:
    forced_decoder_ids = legacy_forced_decoder_ids

Collaborator Author


For backward compatibility reasons, we had to make sure that the config's forced_decoder_ids are still taken into account. TF would probably also suffer from this 😞
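
For readers following along, here is the structure under discussion, with illustrative token ids (they are assumptions, not taken from a real checkpoint):

# (position, token_id) pairs forced at the start of decoding; for multilingual
# Whisper this is typically language at position 1, task at position 2, and,
# when timestamps are off, <|notimestamps|> at position 3
forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50363)]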

self.timestamp_begin = generate_config.no_timestamps_token_id + 1

self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
Collaborator


Is forced_decoder_ids always going to be set?

Collaborator Author


It has to! In generate it will always be given a default value, but we can raise an error for easier debugging.
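
Something like this guard would do it (a sketch, not the merged code):

if getattr(generate_config, "forced_decoder_ids", None) is None:
    raise ValueError(
        "forced_decoder_ids must be set on the generation config before building "
        "the timestamp processor; generate normally fills in a default."
    )
self.begin_index = len(generate_config.forced_decoder_ids) + 2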


forced_decoder_ids = []

if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
    if task is not None or language is not None:
Collaborator


Same notes as in modeling_tf_whisper about checking the inputs

    new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
    return new_scores

def _force_tokens(current_tokens):
Collaborator


nit: My understanding is this function only works with a single token (not tokens)

Suggested change
- def _force_tokens(current_tokens):
+ def _force_token(current_token):
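
For context, a self-contained sketch of what forcing a single token with tf.tensor_scatter_nd_update looks like; names and shapes here are assumptions, not the PR's exact code:

import tensorflow as tf

def _force_token(scores: tf.Tensor, token_id: int) -> tf.Tensor:
    # scores: [batch_size, vocab_size] log-probabilities
    batch_size = tf.shape(scores)[0]
    # start from -inf everywhere, then set the forced token's score to 0
    new_scores = tf.fill(tf.shape(scores), float("-inf"))
    indices = tf.stack([tf.range(batch_size), tf.fill([batch_size], token_id)], axis=1)
    return tf.tensor_scatter_nd_update(new_scores, indices, tf.zeros([batch_size]))

# usage: force token id 50363 for a batch of 2 over a 51865-token vocabulary
forced = _force_token(tf.random.normal([2, 51865]), 50363)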

Comment on lines +1412 to +1416
if generation_config.return_timestamps:
    if logits_processor is not None:
        logits_processor += [TFWhisperTimeStampLogitsProcessor(generation_config)]
    else:
        logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)]
Collaborator


Should there be a default logits_processor set here if generation_config.return_timestamps is False?

Collaborator Author


No, if we don't return timestamps we just don't need an additional processor (but we should make sure that forced_decoder_ids has the no_timestamps_token forced).


class TFWhisperTimeStampLogitsProcessor(TFLogitsProcessor):
    r"""
    Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
Collaborator


Could you expand this a little to specify what's special about it being a timestamp logits processor?

Collaborator Author


Sure, will also update the PyTorch definition to make it clearer!
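
For instance, the expanded docstring could read along these lines (a sketch of the intent, phrased after the PyTorch processor's behaviour):

class TFWhisperTimeStampLogitsProcessor(TFLogitsProcessor):
    r"""
    Whisper-specific processor that enforces the model's timestamp grammar during
    generation: timestamp tokens may only appear at valid positions, predicted
    timestamps must be non-decreasing, and when the probability mass on timestamp
    tokens exceeds that of any single text token, sampling a timestamp is forced.
    Tokens violating these rules have their log probabilities set to -inf.
    """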

    return scores

# len(seq) == 1 corresponds to cur_len == self.begin_index + 2
last_was_timestamp = (cur_len >= self.begin_index + 2) & (input_ids[:, cur_len - 1] >= self.timestamp_begin)
penultimate_was_timestamp = (cur_len < self.begin_index + 3) | (
Collaborator


Just for my understanding: is the cur_len < self.begin_index + 3 line there because the first timestamp won't have a previous timestamp token prediction?

Collaborator Author


Yes! I can add a comment for this 😉
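
Something along these lines would do (a sketch of the comment; the continuation of the truncated expression is an assumption):

# during the first steps after the prompt there is no penultimate generated
# token to inspect yet, so treat "penultimate was a timestamp" as true there
penultimate_was_timestamp = (cur_len < self.begin_index + 3) | (
    input_ids[:, cur_len - 2] >= self.timestamp_begin
)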

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@ArthurZucker
Collaborator Author

Thanks for your review, will address all of this.

@makaveli10

makaveli10 commented Jun 14, 2023

@ArthurZucker I was testing whether I get timestamps with the TF model using your tf-timestamps-whisper branch on Colab, but I see this:

/content/transformers/src/transformers/models/whisper/tokenization_whisper.py in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, output_offsets, time_precision, decode_with_timestamps, **kwargs)
    593         )
    594         if decode_with_timestamps:
--> 595             text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
    596         # retrieve offsets
    597         if output_offsets:

/content/transformers/src/transformers/models/whisper/tokenization_whisper.py in _decode_with_timestamps(self, token_ids, time_precision)
    501         for token in token_ids:
    502             if token >= timestamp_begin:
--> 503                 timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
    504                 outputs.append(timestamp)
    505                 outputs.append([])

/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/ops/gen_math_ops.py in mul(x, y, name)
   6574   if tld.is_eager:
   6575     try:
-> 6576       _result = pywrap_tfe.TFE_Py_FastPathExecute(
   6577         _ctx, "Mul", name, x, y)
   6578       return _result

TypeError: Cannot convert 0.02 to EagerTensor of dtype int32

@ArthurZucker
Collaborator Author

Hey! That's probably because I haven't pulled from main for a while and we changed the Whisper tokenizer. As you can see, the decoding is what's failing here.

@makaveli10

makaveli10 commented Jun 15, 2023

@ArthurZucker Thanks for the response. I got the issue resolved with

timestamp = f"<|{float(token - timestamp_begin) * time_precision:.2f}|>"

i.e. changing token - timestamp_begin to float(token - timestamp_begin)
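
For anyone hitting the same thing, a minimal repro of the dtype clash (illustrative values):

import tensorflow as tf

token = tf.constant(50400)  # int32 token id, as produced by TF generate
timestamp_begin = 50364
time_precision = 0.02

# (token - timestamp_begin) * time_precision  # TypeError: can't multiply an
#                                             # int32 EagerTensor by a Python float
print(float(token - timestamp_begin) * time_precision)  # 0.72 -- casting first works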

@github-actions

github-actions bot commented Aug 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
