Fix flax whisper tokenizer bug #33151

Merged
amyeroberts merged 16 commits into huggingface:main from the fix_flax_whisper_tokenizer_bug branch
Sep 12, 2024

Conversation

hannan72
Contributor

@hannan72 hannan72 commented Aug 27, 2024

What does this PR do?

Fixes a bug when using the Whisper tokenizer with the Flax Whisper model, as reported in issue #32936

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker
@sanchit-gandhi

Fix issue with flax whisper model
Fix issue with flax whisper model
Collaborator

@amyeroberts amyeroberts left a comment

Thanks for opening a PR @hannan72!

Please make sure when opening a PR to only tag a minimal subset of the most relevant people. This ensures the PR is reviewed quickly (no bystander effect), is good practice, and keeps the number of notifications for everyone in check.

@@ -849,7 +849,7 @@ def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_star

  # handle case of empty token_ids for decoding with timestamps.
  # at this point token_ids is a list, so it is safe to use if not check.
- if not token_ids:
+ if token_ids is None or len(token_ids) == 0:
Collaborator

This doesn't match with the comment above -- which indicates unexpected behaviour in the _convert_to_list method.

The linked issue indicates a problem with applying `not` to np arrays, but this doesn't correspond to the None check here

Contributor Author

@hannan72 hannan72 Aug 27, 2024

Good points!
I also updated the comment
Please take a look again @amyeroberts

Collaborator

The comment still doesn't make sense wrt the change. If token_ids really is a list, then checking not token_ids should be safe

Contributor Author

@amyeroberts The problem is that token_ids is not a list for flax models; it is a jax array, so `not token_ids` raises an error.
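
To make the failure mode concrete: NumPy raises a ValueError when a multi-element array is used in a boolean context, and the comment above reports an error for jax arrays as well. The sketch below is illustrative only. `AmbiguousArray` is a hypothetical stand-in (not part of jax, NumPy, or transformers) that models just that `__bool__` behaviour so the example runs without either library installed. It compares a plain `not` check with the `len`-based check from the diff under review.

```python
class AmbiguousArray:
    """Hypothetical stand-in for an array type whose truth value is ambiguous."""

    def __init__(self, values):
        self._values = list(values)

    def __len__(self):
        return len(self._values)

    def __bool__(self):
        # Assumption: mimic array libraries that refuse to reduce
        # several elements to a single True/False.
        raise ValueError("truth value of a multi-element array is ambiguous")


def is_empty_with_not(token_ids):
    # Fine for lists, but triggers __bool__ on array-like objects.
    return not token_ids


def is_empty_with_len(token_ids):
    # Only uses len(), so it never triggers __bool__.
    return token_ids is None or len(token_ids) == 0


token_ids = AmbiguousArray([50258, 50259, 50359])

try:
    is_empty_with_not(token_ids)
    raised = False
except ValueError:
    raised = True
```

The `len`-based check sidesteps the error, but as the review goes on to argue, it leaves the surrounding comment ("token_ids is a list") untrue.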

Contributor Author

@amyeroberts Do you have any other suggestions for resolving the issue for flax models?

Collaborator

The main thing is that the function names and comments should be consistent with the logic. The comment above says the object is a list, which appears not to be true. In addition, the issue indicates _convert_to_list isn't actually converting to a list. So the issue to resolve is: why isn't _convert_to_list converting to a list? Should the method be updated or changed?

Contributor Author

@amyeroberts You're right! So we should update the _convert_to_list method in order to convert jax arrays to lists, as well as numpy, torch and tf arrays.
I updated the PR: I reverted the token_ids check back to `not` and added a couple of lines in _convert_to_list to convert jax arrays to lists.
Is it OK now?
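
The approach agreed on here, converting to a list first so that a later `not token_ids` check is safe, can be sketched as follows. This is a minimal illustration under stated assumptions, not the code merged in this PR: `convert_to_list` and `ArrayLike` are hypothetical names, and the sketch assumes the incoming array type exposes a `tolist()` method, as NumPy arrays do.

```python
class ArrayLike:
    """Hypothetical stand-in for a framework array that exposes tolist()."""

    def __init__(self, values):
        self._values = list(values)

    def tolist(self):
        return list(self._values)

    def __bool__(self):
        # Assumption: like array libraries, refuse an ambiguous truth value.
        raise ValueError("truth value of an array is ambiguous")


def convert_to_list(token_ids):
    """Return token_ids as a plain Python list (hypothetical helper)."""
    if hasattr(token_ids, "tolist"):
        # Array-like input: delegate the conversion to the array itself.
        return token_ids.tolist()
    # Lists, tuples and other iterables.
    return list(token_ids)


def is_empty(token_ids):
    # After conversion the object really is a list, so `not` is safe.
    return not convert_to_list(token_ids)
```

With the conversion done up front, the comment "at this point token_ids is a list" becomes true again, which is the consistency the reviewer asked for.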

just check len of token_ids
just use len of token_ids
@hannan72
Contributor Author

Any other points @amyeroberts ?

…pt and add support to jax arrays in _convert_to_list
…d add support to jax arrays in _convert_to_list
@hannan72
Contributor Author

hannan72 commented Sep 2, 2024

Is it ready to merge @amyeroberts?

@amyeroberts
Collaborator

@hannan72 Thanks for iterating - change now looks OK - final thing is to add a test, which would fail on current main but passes with this fix

@hannan72
Contributor Author

hannan72 commented Sep 5, 2024

@amyeroberts Tests have been added and pass in the automated checks. Could you please do the final review?

Collaborator

@amyeroberts amyeroberts left a comment

Thanks for iterating and adding tests!

General comment that unrelated formatting changes should be removed from the diff. Once the tests are split up we should be good to go

tests/models/whisper/test_tokenization_whisper.py (outdated, resolved)
hannan72 and others added 2 commits September 7, 2024 05:17
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@hannan72
Contributor Author

hannan72 commented Sep 7, 2024

Thanks @amyeroberts for your suggestion! It is applied and tests are split up. Just needs your approval.
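
One way to split such tests so that each framework is exercised only when it is installed is sketched below. The commit list for this PR mentions `require_xxx` decorators. The `require_module` decorator here is a hypothetical stand-in built on the standard library's `unittest.skipUnless`, not the project's own decorators, and `to_list` is a hypothetical helper rather than the tokenizer method under test.

```python
import importlib
import unittest


def require_module(name):
    """Hypothetical decorator: skip the test unless `name` can be imported."""
    try:
        importlib.import_module(name)
        available = True
    except ImportError:
        available = False
    return unittest.skipUnless(available, f"test requires {name}")


def to_list(token_ids):
    # Hypothetical helper under test: array-likes with tolist() become lists.
    return token_ids.tolist() if hasattr(token_ids, "tolist") else list(token_ids)


class ConvertToListTest(unittest.TestCase):
    expected = [1, 2, 3]

    def test_list_input(self):
        self.assertEqual(to_list([1, 2, 3]), self.expected)

    @require_module("numpy")
    def test_numpy_input(self):
        import numpy as np

        self.assertEqual(to_list(np.array([1, 2, 3])), self.expected)

    @require_module("jax")
    def test_jax_input(self):
        import jax.numpy as jnp

        self.assertEqual(to_list(jnp.array([1, 2, 3])), self.expected)


# Run the suite programmatically; skipped tests do not count as failures.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ConvertToListTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Keeping one test method per framework means a missing optional dependency shows up as a skip for that case only, instead of hiding the other cases behind a single combined test.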

@hannan72
Contributor Author

hannan72 commented Sep 9, 2024

> Thanks for iterating and adding tests!
>
> General comment that unrelated formatting changes should be removed from the diff. Once the tests are split up we should be good to go

Anything else?

@@ -204,11 +217,21 @@ def test_skip_special_tokens_skips_prompt_ids(self):
# fmt: on
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
Collaborator

Can you remove all these changes, which shouldn't be applied? (Our line length is 120, and this is a formatting change unrelated to the PR.)

Contributor Author

@amyeroberts All unrelated changes have been reverted. Is the PR ready to merge now?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hannan72
Contributor Author

@amyeroberts All changes have been made. I'd appreciate it if you could merge the PR!

Collaborator

@amyeroberts amyeroberts left a comment

Great - thanks for fixing and iterating on this!

@amyeroberts amyeroberts merged commit 8ed6352 into huggingface:main Sep 12, 2024
18 checks passed
@hannan72 hannan72 deleted the fix_flax_whisper_tokenizer_bug branch September 12, 2024 15:03
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* Update tokenization_whisper.py

Fix issue with flax whisper model

* Update tokenization_whisper_fast.py

Fix issue with flax whisper model

* Update tokenization_whisper.py

just check len of token_ids

* Update tokenization_whisper_fast.py

just use len of token_ids

* Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list

* Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list

* Update test_tokenization_whisper.py to add test for _convert_to_list method

* Update test_tokenization_whisper.py to fix code style issues

* Fix code style

* Fix code check again

* Update test_tokenization_whisper.py to improve code style

* Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available

* Update tests/models/whisper/test_tokenization_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method

* Revert the changes that were automatically applied by the formatter and were unrelated to the PR

* Format for minimal changes

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
amyeroberts added a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024