fix prompt strip to support tensors and np arrays #27818
Conversation
I was not able to make it pass the CI/CD tests due to import issues. @sanchit-gandhi can you please guide?
Hey @AvivSham! Thanks for opening this PR - the first draft looks in decent nick!
One thing we need to change is how we check for Torch/TF tensors: the feature extraction file should be agnostic to framework. Imagine we have a user running Whisper in PyTorch. We do not force them to have TF installed, since there's no need to have both PyTorch and TF installed if just using PyTorch. Therefore, we cannot assume TF is an available import, and thus can't leverage it in the feature extraction file. The converse is true for someone using TF, where we don't expect them to have PyTorch installed.
=> what we need to do here is refactor the changes such that they are agnostic to framework. I've left some hints on how to do this below! Let me know if you have any questions - happy to help!
```python
if has_prompt:
    if not isinstance(token_ids, list):
        token_ids = self._convert_to_list(token_ids)
```
Let's convert the token ids to a list before checking whether we have a prompt. If we convert the token ids to a list first, then we don't need to check `isinstance(token_ids, (List[int], np.ndarray, torch.Tensor, tf.Tensor))`, since we know `token_ids` will be a list!
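To make this concrete, here is a rough, framework-agnostic sketch of what the conversion and the reordered prompt stripping could look like. It is a sketch only: the function and argument names follow the ones used in this thread, and the duck-typed `.cpu()` / `.numpy()` checks are one possible way to stay framework-agnostic, not necessarily what the final implementation does.

```python
import numpy as np


def _convert_to_list(token_ids):
    # Torch and TF tensors both expose `.cpu()` and `.numpy()`, so we can detect
    # them by duck typing without importing either framework.
    if hasattr(token_ids, "cpu"):
        token_ids = token_ids.cpu()
    if hasattr(token_ids, "numpy"):
        token_ids = token_ids.numpy()
    if isinstance(token_ids, np.ndarray):
        token_ids = token_ids.tolist()
    return token_ids


def _strip_prompt(token_ids, prompt_token_id, decoder_start_token_id):
    # Convert first, so the prompt check below works identically for lists,
    # NumPy arrays, and Torch/TF tensors.
    token_ids = _convert_to_list(token_ids)
    has_prompt = bool(token_ids) and token_ids[0] == prompt_token_id
    if has_prompt:
        if decoder_start_token_id in token_ids:
            # keep everything from the decoder start token onwards
            return token_ids[token_ids.index(decoder_start_token_id):]
        return []
    return token_ids
```

The exact detection mechanism matters less than the point of the review comment: neither `torch` nor `tensorflow` needs to be imported in the tokenization file.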
This is still valid! We need to convert to list before checking whether we have a prompt.
Thank you for reviewing @sanchit-gandhi!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This PR is not stale, waiting for @sanchit-gandhi to approve.
Gentle ping @sanchit-gandhi
@amyeroberts @sanchit-gandhi
@amyeroberts This issue is not stale, still waiting for @sanchit-gandhi.
@sanchit-gandhi @ylacombe Can one of you review this?
Thanks for the last round of commits @AvivSham! Super sorry for the late review here.
I would recommend using the following script to quickly debug your changes and confirm they're working (currently the following code fails since we don't strip the prompt for numpy inputs):
```python
import numpy as np

from transformers import WhisperTokenizer

prompt_text = "the cat sat on the mat"
input_text = "the quick brown fox"

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
prompt_ids = tokenizer.get_prompt_ids("the cat sat on the mat", return_tensors="np")
input_ids = tokenizer("the quick brown fox", return_tensors="np").input_ids[0]
input_ids = np.hstack([prompt_ids, input_ids])

# check with prompt in output
pred_text = tokenizer.decode(input_ids, skip_special_tokens=False)
assert pred_text.strip() == prompt_text + input_text

# check stripping prompt from output
pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
assert pred_text.strip() == input_text
```
Along with the subtle changes to the order of operations (which will fix the above example), it would be great to add a slow/fast tokenizer test to confirm we get the correct behaviour after this change. To do this, you can add a test in the file test_tokenization_whisper.py. I would follow the existing logic to load pre-trained slow and fast (or "rust") tokenizers, then run a similar check to the one we have above. You can use the following test as an example:
transformers/tests/models/whisper/test_tokenization_whisper.py (lines 269 to 271 at ce3647a):

```python
def test_basic_normalizer(self):
    tokenizer = self.get_tokenizer()
    rust_tokenizer = self.get_rust_tokenizer()
```
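For illustration only, here is a hedged sketch of what such a prompt-decoding test could look like; the test name, expected strings, and exact assertions are hypothetical and would need adapting to the existing test class:

```python
def test_decode_prompt_ids_from_np(self):
    # hypothetical test following the slow/fast pattern shown above
    import numpy as np  # the real test file likely imports this at module level

    prompt = "the cat sat on the mat"
    text = "the quick brown fox"

    for tok in [self.get_tokenizer(), self.get_rust_tokenizer()]:
        prompt_ids = tok.get_prompt_ids(prompt, return_tensors="np")
        input_ids = tok(text, return_tensors="np").input_ids[0]
        merged = np.hstack([prompt_ids, input_ids])

        # with skip_special_tokens=True the prompt should now be stripped,
        # even though the ids are a NumPy array rather than a Python list
        decoded = tok.decode(merged, skip_special_tokens=True)
        self.assertEqual(decoded.strip(), text)
```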
Let me know if you have any questions! Happy to help with the final changes before merge!
Hi @sanchit-gandhi,
Hey @AvivSham! The functionality looks great here! Just a few clean-up comments below. Once done, could you run:

```
make style
```

in order to run the linter and format the code? If you hit issues regarding missing packages, you can install them with the following from the root of the Transformers repo:

```
pip install -e ".[quality]"
```
Many thanks!
@sanchit-gandhi
Looks like there was a minor typo (suggested a fix below), which running `make fix-copies` should check is correct. By the way, there's great documentation on the style checks we use in Transformers in this contributing guide.
It looks like a Whisper tokenizer test is failing on the CI for this PR (traceback). You can run the test locally from the root of the repo with the following command:
```
pytest -sv tests/models/whisper/test_tokenization_whisper.py::WhisperTokenizerTest::test_skip_special_tokens_with_timestamps
```

Could you have a look to see why this test is failing after these PR changes? Happy to help if it's unclear! If you have any dependency issues when running the test, you can install the necessary packages with:

```
pip install -e ".[quality]"
```
Hi @sanchit-gandhi,

```python
if not isinstance(token_ids, int):
    if len(token_ids) == 0:
        raise ValueError("No token ids provided for decoding.")
```

or, if you see a reason for decoding an empty string, we can return one instead.
Hey @AvivSham - sorry I don't fully follow! If we look at the test test_skip_special_tokens_with_timestamps, we see that we always pass a non-empty array of token ids to the tokeniser's decode method.
Ok, so I debugged it and the empty token ids come from the `_decode_with_timestamps` path.
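Judging from the commit titles later in the thread ("fixing _strip_prompt to handle _decode_with_timestamps"), the resolution appears to be an early return for empty sequences rather than raising. A minimal sketch of that guard, with the same illustrative names as the earlier sketch (not the exact code merged in the PR):

```python
def _strip_prompt(token_ids, prompt_token_id, decoder_start_token_id):
    # assume token_ids has already been through the framework-agnostic
    # conversion sketched earlier, i.e. it is a plain Python list here
    token_ids = list(token_ids)

    # Timestamp decoding can hand over an empty segment; treat it as a no-op
    # instead of raising, so `_decode_with_timestamps` keeps working.
    if not token_ids:
        return token_ids

    if token_ids[0] == prompt_token_id:
        if decoder_start_token_id in token_ids:
            return token_ids[token_ids.index(decoder_start_token_id):]
        return []
    return token_ids


# the empty case discussed above simply passes through (id values are placeholders)
assert _strip_prompt([], prompt_token_id=50361, decoder_start_token_id=50258) == []
```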
Hi @AvivSham, you need to run `make fix-copies`.
cc @kamilakesbi for first review as @sanchit-gandhi is off.
@amyeroberts I ran it and committed. Let's see if the CI/CD passes this time.
@amyeroberts @kamilakesbi
Hi @AvivSham, thanks for working on this!
It LGTM! We can merge this PR @amyeroberts :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for adding! Just one comment
Hi @amyeroberts,
Thanks for adding!
@amyeroberts I keep receiving emails regarding failed CI/CD runs; how can we prevent that? I thought they would stop once the PR was merged.
Hi @AvivSham, this is because of the settings on your fork. You can disable the CI runs by going to the repo, then Settings > Actions > Disable actions.
@amyeroberts Thank you very much for helping.
* fix prompt strip to support tensors and np arrays
* framework agnostic
* change logic check before converting prompt into list

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding _convert_to_list to tokenization_whisper_fast
* adding tests for prompt decoding
* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* revert minor
* make style formatting
* style formatting after update
* Update src/transformers/models/whisper/tokenization_whisper_fast.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixing _strip_prompt to handle _decode_with_timestamps
* fix copies

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
What does this PR do?
`WhisperTokenizer` does not strip the prompt ids when calling `decode` for `torch.Tensor`, `tf.Tensor`, and `np.ndarray` inputs. It does not perform the slicing in `_strip_prompt` because the condition `isinstance(token_ids, list)` is not met.
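A minimal illustration of the condition described above (the id values are placeholders; see the reviewer's debug script earlier in the thread for a full reproduction):

```python
import numpy as np

token_ids = np.array([50361, 2425, 50258])  # placeholder ids, returned as an array
# the old code gated prompt stripping on the ids being a Python list:
print(isinstance(token_ids, list))  # False -> the slicing in `_strip_prompt` was skipped
```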
Who can review?
@sanchit-gandhi
@ArthurZucker