Skip to content

Commit

Permalink
add a test checking the format of convert_tokens_to_string's output (
Browse files Browse the repository at this point in the history
…#16540)

* add new tests

* add comment to overridden tests
  • Loading branch information
SaulLu authored Apr 4, 2022
1 parent 24a85cc commit be9474b
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/byt5/test_tokenization_byt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,14 @@ def test_pretokenized_inputs(self):
# tests all ids in vocab => vocab doesn't exist so unnecessary to test
def test_conversion_reversible(self):
pass

def test_convert_tokens_to_string_format(self):
# The default common tokenizer tests uses invalid tokens for ByT5 that can only accept one-character strings
# and special added tokens as tokens
tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokens = ["t", "h", "i", "s", " ", "i", "s", " ", "a", " ", "t", "e", "x", "t", "</s>"]
string = tokenizer.convert_tokens_to_string(tokens)

self.assertIsInstance(string, str)
11 changes: 11 additions & 0 deletions tests/perceiver/test_tokenization_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,14 @@ def test_pretokenized_inputs(self):
# tests all ids in vocab => vocab doesn't exist so unnecessary to test
def test_conversion_reversible(self):
pass

def test_convert_tokens_to_string_format(self):
# The default common tokenizer tests uses invalid tokens for Perceiver that can only accept one-character
# strings and special added tokens as tokens
tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokens = ["[CLS]", "t", "h", "i", "s", " ", "i", "s", " ", "a", " ", "t", "e", "s", "t", "[SEP]"]
string = tokenizer.convert_tokens_to_string(tokens)

self.assertIsInstance(string, str)
9 changes: 9 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3713,6 +3713,15 @@ def test_saving_tokenizer_trainer(self):
trainer.save_model(os.path.join(tmp_dir, "checkpoint"))
self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))

def test_convert_tokens_to_string_format(self):
tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokens = ["this", "is", "a", "test"]
string = tokenizer.convert_tokens_to_string(tokens)

self.assertIsInstance(string, str)

def test_save_slow_from_fast_and_reload_fast(self):
if not self.test_slow_tokenizer or not self.test_rust_tokenizer:
# we need both slow and fast versions
Expand Down
11 changes: 11 additions & 0 deletions tests/wav2vec2/test_tokenization_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,14 @@ def test_tf_encode_plus_sent_to_model(self):
@unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
def test_torch_encode_plus_sent_to_model(self):
pass

def test_convert_tokens_to_string_format(self):
# The default common tokenizer tests assumes that the output of `convert_tokens_to_string` is a string which
# is not the case for Wav2vec2.
tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokens = ["T", "H", "I", "S", "|", "I", "S", "|", "A", "|", "T", "E", "X", "T"]
output = tokenizer.convert_tokens_to_string(tokens)

self.assertIsInstance(output["text"], str)
11 changes: 11 additions & 0 deletions tests/wav2vec2_phoneme/test_tokenization_wav2vec2_phoneme.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,14 @@ def test_tf_encode_plus_sent_to_model(self):
@unittest.skip("The tokenizer shouldn't be used to encode input IDs (except for labels), only to decode.")
def test_torch_encode_plus_sent_to_model(self):
pass

def test_convert_tokens_to_string_format(self):
# The default common tokenizer tests assumes that the output of `convert_tokens_to_string` is a string which
# is not the case for Wav2Vec2PhonemeCTCTokenizer.
tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokens = ["ð", "ɪ", "s", "ɪ", "z", "ɐ", "t", "ɛ", "k", "s", "t"]
output = tokenizer.convert_tokens_to_string(tokens)

self.assertIsInstance(output["text"], str)

0 comments on commit be9474b

Please sign in to comment.