Skip to content

Commit

Permalink
Fix flax whisper tokenizer bug (huggingface#33151)
Browse files Browse the repository at this point in the history
* Update tokenization_whisper.py

Fix issue with flax whisper model

* Update tokenization_whisper_fast.py

Fix issue with flax whisper model

* Update tokenization_whisper.py

just check len of token_ids

* Update tokenization_whisper_fast.py

just use len of token_ids

* Update tokenization_whisper_fast.py: revert changes in _strip_prompt and add support for JAX arrays in _convert_to_list

* Update tokenization_whisper.py: revert changes in _strip_prompt and add support for JAX arrays in _convert_to_list

* Update test_tokenization_whisper.py to add test for _convert_to_list method

* Update test_tokenization_whisper.py to fix code style issues

* Fix code style

* Fix code check again

* Update test_tokenization_whisper.py to improve code style

* Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available

* Update tests/models/whisper/test_tokenization_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method

* Revert the changes that were automatically applied by the formatter and were unrelated to the PR

* Format for minimal changes

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
hannan72 and amyeroberts committed Oct 2, 2024
1 parent 5435902 commit b27f880
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
41 changes: 40 additions & 1 deletion tests/models/whisper/test_tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from transformers.testing_utils import require_flax, require_tf, require_torch, slow

from ...test_tokenization_common import TokenizerTesterMixin

Expand Down Expand Up @@ -574,3 +574,42 @@ def test_offset_decoding(self):

output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

def test_convert_to_list_np(self):
    """Check `_convert_to_list` on a nested list and on a NumPy array.

    Both the slow and the fast tokenizer classes must return the same
    nested Python list for either kind of input.
    """
    expected = [[1, 2, 3], [4, 5, 6]]
    tokenizer_classes = (WhisperTokenizer, WhisperTokenizerFast)

    # Input that is already a nested list should come back equal to itself.
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(expected), expected)

    # A NumPy array should be converted to the equivalent nested list.
    as_numpy = np.array(expected)
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(as_numpy), expected)

@require_tf
def test_convert_to_list_tf(self):
    """Check that `_convert_to_list` turns a TensorFlow tensor into a nested list.

    The assertion is made for both the slow and the fast tokenizer classes.
    """
    # Imported locally: this test is gated by the `require_tf` decorator.
    import tensorflow as tf

    expected = [[1, 2, 3], [4, 5, 6]]
    as_tensor = tf.constant(expected)
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_tensor), expected)

@require_flax
def test_convert_to_list_jax(self):
    """Check that `_convert_to_list` turns a JAX array into a nested list.

    The assertion is made for both the slow and the fast tokenizer classes.
    """
    # Imported locally: this test is gated by the `require_flax` decorator.
    import jax.numpy as jnp

    expected = [[1, 2, 3], [4, 5, 6]]
    as_jax = jnp.array(expected)
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_jax), expected)

@require_torch
def test_convert_to_list_pt(self):
    """Check that `_convert_to_list` turns a PyTorch tensor into a nested list.

    The assertion is made for both the slow and the fast tokenizer classes.
    """
    # Imported locally: this test is gated by the `require_torch` decorator.
    import torch

    expected = [[1, 2, 3], [4, 5, 6]]
    as_tensor = torch.tensor(expected)
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(as_tensor), expected)

0 comments on commit b27f880

Please sign in to comment.