[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer #11538

Merged · 6 commits · May 3, 2021
Changes from all commits
59 changes: 57 additions & 2 deletions src/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -24,8 +24,8 @@
import numpy as np

from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
-from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
+from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...utils import logging


@@ -277,6 +277,61 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =

return (vocab_file,)

def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added
to it with indices starting from the length of the current vocabulary.

Args:
new_tokens (:obj:`List[str]` or :obj:`List[tokenizers.AddedToken]`):
Token(s) to add to the vocabulary. A token is only added if it's not already in the vocabulary (tested by
checking whether the tokenizer assigns the index of the ``unk_token`` to it).
special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the tokens should be added as special tokens.

Returns:
:obj:`int`: The number of tokens actually added to the vocabulary.

Examples::

# Let's see how to increase the vocabulary of the Wav2Vec2 model and tokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
"""
new_tokens = [str(tok) for tok in new_tokens]

tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
token = token.lower()
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in tokens_to_add
):
tokens_to_add.append(token)
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")

# new ids start right after the current end of the vocabulary (len(self) also counts previously added tokens)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)

# Make sure we don't split on any special tokens (even if they were already in the vocab before)
for token in tokens_to_add:
if len(token) > 1:
self._additional_special_tokens.append(AddedToken(token))
Reviewer comment (Contributor): This is important to make sure tokens longer than 1 character are, on the one hand, not split and, on the other hand, do not strip the left and right space. (A short usage sketch follows this file's diff.)

_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)

return len(tokens_to_add)


class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"""
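A minimal usage sketch of the fixed behavior, assuming a transformers build that includes this patch and access to the facebook/wav2vec2-base-960h checkpoint; the expected ids are the ones asserted by the new tests further down.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")

# Single-character tokens go straight into the regular vocabulary.
tokenizer.add_tokens("x")
print(tokenizer("C x A").input_ids)    # [19, 4, 32, 4, 7], as asserted in test_tokenizer_add_token_chars

# Tokens longer than one character are additionally registered as no-split AddedTokens
# (the len(token) > 1 guard above), so they are kept intact together with their spaces.
tokenizer.add_tokens(["aaa"])
print(tokenizer("C aaa A").input_ids)  # should give [19, 4, 33, 4, 7], cf. test_tokenizer_add_token_words

Only the multi-character tokens take the AddedToken / unique_no_split_tokens path; single characters behave like ordinary vocabulary entries.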
84 changes: 84 additions & 0 deletions tests/test_tokenization_wav2vec2.py
@@ -375,6 +375,38 @@ def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def test_tokenizer_add_token_chars(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

# check adding a single token
tokenizer.add_tokens("x")
token_ids = tokenizer("C x A").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7])

tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("C a A c").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])

# re-adding the same tokens is a no-op (ids are unchanged), and added single chars are recognized even without surrounding spaces
tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("CaA c").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35])

def test_tokenizer_add_token_words(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

# check adding a single token
tokenizer.add_tokens("xxx")
token_ids = tokenizer("C xxx A B").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])

tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("C aaa A ccc B B").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])

# re-adding is a no-op, and a multi-char added token is recognized even without surrounding spaces ("CaaaA" -> C, aaa, A)
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("CaaaA ccc B B").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])

def test_tokenizer_decode(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

@@ -470,3 +502,55 @@ def test_special_characters_in_vocab(self):
def test_pretrained_model_lists(self):
# Wav2Vec2Model has no max model length => no testing
pass

# overwrite from test_tokenization_common
def test_add_tokens_tokenizer(self):
Reviewer comment (Contributor): Wav2Vec2 should split on newly added tokens that are longer than 1 character. Doing

tok.add_tokens("aaa")

should then encode/decode such that

tok.decode(tok("hello this is to encode aaa").input_ids) == "hello this is to encode aaa"

and not

tok.decode(tok("hello this is to encode aaa").input_ids) == "hello this is to encodeaaa"

(A runnable sketch of this round trip follows the test below.)

tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab_size = tokenizer.vocab_size
all_size = len(tokenizer)

self.assertNotEqual(vocab_size, 0)

# We usually have added tokens from the start in tests because our vocab fixtures are
# smaller than the original vocabs - let's not assert this
# self.assertEqual(vocab_size, all_size)

new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)
vocab_size_2 = tokenizer.vocab_size
all_size_2 = len(tokenizer)

self.assertNotEqual(vocab_size_2, 0)
self.assertEqual(vocab_size, vocab_size_2)
self.assertEqual(added_toks, len(new_toks))
self.assertEqual(all_size_2, all_size + len(new_toks))

tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

self.assertGreaterEqual(len(tokens), 4)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)

new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer)

self.assertNotEqual(vocab_size_3, 0)
self.assertEqual(vocab_size, vocab_size_3)
self.assertEqual(added_toks_2, len(new_toks_2))
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

tokens = tokenizer.encode(
">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
)

self.assertGreaterEqual(len(tokens), 6)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[0], tokens[1])
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokens[-4])
self.assertEqual(tokens[0], tokenizer.eos_token_id)
self.assertEqual(tokens[-3], tokenizer.pad_token_id)
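
Following the review comment on this test, here is a minimal sketch of the encode/decode round trip for a multi-character added token. It assumes the same facebook/wav2vec2-base-960h checkpoint as the tests above and uses uppercase input to match that checkpoint's character vocabulary; the expected output follows the reviewer's description rather than a fresh run.

from transformers import Wav2Vec2CTCTokenizer

tok = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
tok.add_tokens("aaa")

ids = tok("C aaa A").input_ids
# "aaa" stays a single token and its surrounding spaces survive decoding,
# so the round trip should give "C aaa A" rather than "CaaaA".
print(tok.decode(ids))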