🚨🚨[Whisper Tok] Update integration test (#29368)
* [Whisper Tok] Update integration test

* make style
sanchit-gandhi committed Mar 1, 2024
1 parent e7b9837 commit 0a0a279
Showing 1 changed file with 8 additions and 30 deletions.

tests/models/whisper/test_tokenization_whisper.py
@@ -16,7 +16,7 @@
 
 from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
 from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
-from transformers.testing_utils import require_jinja, slow
+from transformers.testing_utils import slow
 
 from ...test_tokenization_common import TokenizerTesterMixin
 
@@ -67,26 +67,26 @@ def test_full_tokenizer(self):
         tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
 
         tokens = tokenizer.tokenize("This is a test")
-        self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġ", "test"])
+        self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġtest"])
 
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens),
-            [5723, 307, 257, 220, 31636],
+            [5723, 307, 257, 1500],
         )
 
         tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
         self.assertListEqual(
             tokens,
-            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."],  # fmt: skip
-        )  # fmt: skip
+            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."],  # fmt: skip
+        )
         ids = tokenizer.convert_tokens_to_ids(tokens)
-        self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13])
+        self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 341, 307, 16720, 526, 13])
 
         back_tokens = tokenizer.convert_ids_to_tokens(ids)
         self.assertListEqual(
             back_tokens,
-            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."],  # fmt: skip
-        )  # fmt: skip
+            ["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."],  # fmt: skip
+        )
 
     def test_tokenizer_slow_store_full_signature(self):
         pass
@@ -499,25 +499,3 @@ def test_offset_decoding(self):
 
         output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
         self.assertEqual(output, [])
-
-    @require_jinja
-    def test_tokenization_for_chat(self):
-        multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
-        # This is in English, but it's just here to make sure the chat control tokens are being added properly
-        test_chats = [
-            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
-            [
-                {"role": "system", "content": "You are a helpful chatbot."},
-                {"role": "user", "content": "Hello!"},
-                {"role": "assistant", "content": "Nice to meet you."},
-            ],
-            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
-        ]
-        tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
-        expected_tokens = [
-            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
-            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
-            [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
-        ]
-        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
-            self.assertListEqual(tokenized_chat, expected_tokens)
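The updated expectations in test_full_tokenizer can be sanity-checked outside the test suite. A minimal sketch, assuming the test fixture saved to tmpdirname corresponds to the openai/whisper-tiny tokenizer (the checkpoint referenced elsewhere in this file); the exact token IDs depend on that checkpoint's vocabulary:

# Sanity-check sketch; assumes the openai/whisper-tiny checkpoint.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

tokens = tokenizer.tokenize("This is a test")
print(tokens)  # per the updated test: ["This", "Ġis", "Ġa", "Ġtest"]
print(tokenizer.convert_tokens_to_ids(tokens))  # per the updated test: [5723, 307, 257, 1500]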
