Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flax whisper tokenizer bug #33151

Merged
merged 16 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
41 changes: 40 additions & 1 deletion tests/models/whisper/test_tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from transformers.testing_utils import require_flax, require_tf, require_torch, slow

from ...test_tokenization_common import TokenizerTesterMixin

Expand Down Expand Up @@ -574,3 +574,42 @@ def test_offset_decoding(self):

output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

def test_convert_to_list_np(self):
    """Check `_convert_to_list` on a nested list and on a numpy array.

    Both `WhisperTokenizer` and `WhisperTokenizerFast` are exercised, and each
    result is compared against the same nested Python list.
    """
    expected = [[1, 2, 3], [4, 5, 6]]
    tokenizer_classes = (WhisperTokenizer, WhisperTokenizerFast)

    # Input that is already a list of lists.
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(expected), expected)

    # Same data wrapped in a numpy array.
    as_numpy = np.array(expected)
    for tokenizer_cls in tokenizer_classes:
        self.assertListEqual(tokenizer_cls._convert_to_list(as_numpy), expected)

@require_tf
def test_convert_to_list_tf(self):
    """Check `_convert_to_list` on a TensorFlow constant for both tokenizer classes."""
    import tensorflow as tf

    expected = [[1, 2, 3], [4, 5, 6]]
    tensor = tf.constant(expected)
    # Slow tokenizer first, then the fast one; both must yield the nested list.
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(tensor), expected)

@require_flax
def test_convert_to_list_jax(self):
    """Check `_convert_to_list` on a `jax.numpy` array for both tokenizer classes."""
    import jax.numpy as jnp

    expected = [[1, 2, 3], [4, 5, 6]]
    array = jnp.array(expected)
    # Slow tokenizer first, then the fast one; both must yield the nested list.
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(array), expected)

@require_torch
def test_convert_to_list_pt(self):
    """Check `_convert_to_list` on a PyTorch tensor for both tokenizer classes."""
    import torch

    expected = [[1, 2, 3], [4, 5, 6]]
    tensor = torch.tensor(expected)
    # Slow tokenizer first, then the fast one; both must yield the nested list.
    for tokenizer_cls in (WhisperTokenizer, WhisperTokenizerFast):
        self.assertListEqual(tokenizer_cls._convert_to_list(tensor), expected)