added assertions
Ita Zaporozhets authored and ArthurZucker committed May 14, 2024
1 parent 357dfc7 commit 6ed07ac
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions tests/test_tokenization_common.py
@@ -4196,13 +4196,22 @@ def test_split_special_tokens(self):
         self.assertEqual(cr_output, r_output)
         self.assertTrue(special_token not in p_output)
 
-        p_output_explicit = tokenizer_p.tokenize(f"Hey this is a {special_token} token",
-                                                 split_special_tokens=False)
+        p_output_explicit = tokenizer_p.tokenize(f"Hey this is a {special_token} token", split_special_tokens=False)
+        r_output_explicit = tokenizer_r.tokenize(f"Hey this is a {special_token} token", split_special_tokens=False)
+        cr_output_explicit = tokenizer_cr.tokenize(f"Hey this is a {special_token} token", split_special_tokens=False)
+
         self.assertTrue(special_token in p_output_explicit)
+        self.assertEqual(p_output_explicit, r_output_explicit)
+        self.assertEqual(cr_output_explicit, r_output_explicit)
 
-        special_token_id = tokenizer_r.encode(special_token, add_special_tokens=False)[0]
+        p_special_token_id = tokenizer_p.encode(special_token, add_special_tokens=False)[0]
         p_output = tokenizer_p(f"Hey this is a {special_token} token")
-        self.assertTrue(special_token_id not in p_output)
+        r_output = tokenizer_r(f"Hey this is a {special_token} token")
+        cr_output = tokenizer_cr(f"Hey this is a {special_token} token")
+
+        self.assertTrue(p_special_token_id not in p_output)
+        self.assertEqual(p_output, r_output)
+        self.assertEqual(cr_output, r_output)
 
         tmpdirname = tempfile.mkdtemp()
         tokenizer_p.save_pretrained(tmpdirname)
@@ -4211,8 +4220,7 @@ def test_split_special_tokens(self):
         output_reloaded = fast_from_saved.tokenize(f"Hey this is a {special_token} token")
         self.assertTrue(special_token not in output_reloaded)
 
-        output_explicit_reloaded = fast_from_saved.tokenize(f"Hey this is a {special_token} token",
-                                                            split_special_tokens=False)
+        output_explicit_reloaded = fast_from_saved.tokenize(f"Hey this is a {special_token} token", split_special_tokens=False)
         self.assertTrue(special_token in output_explicit_reloaded)
 
     def test_added_tokens_serialization(self):
