Skip to content

Commit

Permalink
Allow add_tokens for ESM (#28535)
Browse files Browse the repository at this point in the history
* Allow non-special tokens to be added

* Add test, fix token adding code

* Revert changes to id_to_token and token_to_id

* Update the ESM tokenizer to be a bit more standardized

* Update src/transformers/models/esm/tokenization_esm.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
Rocketknight1 and ArthurZucker authored Jan 19, 2024
1 parent 5b7f4bc commit d157815
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/transformers/models/esm/tokenization_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# limitations under the License.
"""Tokenization classes for ESM."""
import os
from typing import List, Optional, Union
from typing import List, Optional

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import AddedToken
from ...utils import logging


Expand Down Expand Up @@ -91,11 +90,10 @@ def _convert_token_to_id(self, token: str) -> int:
def _tokenize(self, text, **kwargs):
return text.split()

def get_vocab_size(self, with_added_tokens=False):
return len(self._id_to_token)

def get_vocab(self):
return {token: i for i, token in enumerate(self.all_tokens)}
base_vocab = self._token_to_id.copy()
base_vocab.update(self.added_tokens_encoder)
return base_vocab

def token_to_id(self, token: str) -> int:
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
Expand Down Expand Up @@ -156,7 +154,4 @@ def save_vocabulary(self, save_directory, filename_prefix):

@property
def vocab_size(self) -> int:
return self.get_vocab_size(with_added_tokens=False)

def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
return super()._add_tokens(new_tokens, special_tokens=True)
return len(self.all_tokens)
22 changes: 22 additions & 0 deletions tests/models/esm/test_tokenization_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,25 @@ def test_tokenize_special_tokens(self):
self.assertEqual(len(token_2), 1)
self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
self.assertEqual(token_2[0], SPECIAL_TOKEN_2)

def test_add_tokens(self):
tokenizer = self.tokenizer_class(self.vocab_file)

vocab_size = len(tokenizer)
self.assertEqual(tokenizer.add_tokens(""), 0)
self.assertEqual(tokenizer.add_tokens("testoken"), 1)
self.assertEqual(tokenizer.add_tokens(["testoken1", "testtoken2"]), 2)
self.assertEqual(len(tokenizer), vocab_size + 3)

self.assertEqual(tokenizer.add_special_tokens({}), 0)
self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": "<testtoken1>"})
self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
self.assertEqual(
tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
)
self.assertIn("<testtoken3>", tokenizer.special_tokens_map["additional_special_tokens"])
self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list)
self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2)

self.assertEqual(len(tokenizer), vocab_size + 8)

0 comments on commit d157815

Please sign in to comment.