-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix flax whisper tokenizer bug #33151
Fix flax whisper tokenizer bug #33151
Conversation
Fix issue with flax whisper model
Fix issue with flax whisper model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for opening a PR @hannan72!
Please make sure when opening a PR to only tag a minimal subset of the most relevant people. This ensures the PR is reviewed quickly (no by-stander effect) and is good practice and it keeps the number of notifications for everyone in check.
@@ -849,7 +849,7 @@ def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_star | |||
|
|||
# handle case of empty token_ids for decoding with timestamps. | |||
# at this point token_ids is a list, so it is safe to use if not check. | |||
if not token_ids: | |||
if token_ids is None or len(token_ids) == 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't match with the comment above -- which indicates unexpected behaviour in the _convert_to_list
method.
The linked issue indicates a problem with checking not
on np arrays, but this doesn't correspond to the None
check here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good points!
I also updated the comment
Please take a look again @amyeroberts
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment still doesn't make sense wrt the change. If token_ids
really is a list, then checking not token_ids
should be safe
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts The problem is that token_ids is not a list for flax models, but it is a jax array. So that it raises error for not token_ids
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts Do you have any other suggestion for resolving the issue for flax models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main thing is that the function names and comments should be consistent with the logic. So in the comment above it says the object is a list, which appears not to be true. In addition, the issue indicates _convert_to_list
isn't actually converting to a list. So the issues to resolve is why isn't _convert_to_list
converting to list? Should the method be updated or changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts You're right! So we should update _convert_to_list
method inorder to convert jax arrays to list as long as numpy, torch and tf arrays.
I updated the PR, reverted the changes in checking token_ids
again with not
and add a couple of lines in _convert_to_list
to convert jax array to list.
Is it OK now?
just check len of token_ids
just use len of token_ids
Any other points @amyeroberts ? |
…pt and add support to jax arrays in _convert_to_list
…d add support to jax arrays in _convert_to_list
Is it ready to merge @amyeroberts? |
@hannan72 Thanks for iterating - change now looks OK - final thing is to add a test, which would fail on current |
…odules if available
@amyeroberts Test codes has been added and passed by automatic tests. Code you please do the final review? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating and adding tests!
General comment that unrelated formatting changes should be removed from the diff. Once the tests are split up we should be good to go
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
…stead of `is_xxx_available()` method
Thanks @amyeroberts for your suggestion! It is applied and tests are split up. Just needs your approval. |
Anything else? |
@@ -204,11 +217,21 @@ def test_skip_special_tokens_skips_prompt_ids(self): | |||
# fmt: on | |||
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>" | |||
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity." | |||
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove all these changes which shouldn't be applied (our line length is 120 and this is a formatting change unrelated to the PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts All unrelated changes have been reverted. Now is it the proper time for merging the PR?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@amyeroberts All changes has been made. I'd appreciate it if you merge the PR! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great - thanks for fixing and iterating on this!
* Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization)whisper.py to Improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes automatically applied by formatter and was unrelated to PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization)whisper.py to Improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes automatically applied by formatter and was unrelated to PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization)whisper.py to Improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes automatically applied by formatter and was unrelated to PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization)whisper.py to Improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes automatically applied by formatter and was unrelated to PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
Fixes Bug when using whisper tokenizer for flax whisper model, according to the issue #32936
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@ArthurZucker
@sanchit-gandhi