
Issue with Adding New Tokens to ESM2 Model Tokenizer #28387

Closed
mahdip72 opened this issue Jan 8, 2024 · 16 comments · Fixed by #28535

Comments


mahdip72 commented Jan 8, 2024

Hello

I am encountering an issue while working with the ESM2 models (facebook/esm2_t6_8M_UR50D). Specifically, when I try to add new tokens to the tokenizer, they are automatically classified as special tokens, even though I am specifying special_tokens=False.

Here is the code snippet I am using:

from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "facebook/esm2_t6_8M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
num_added_toks = tokenizer.add_tokens(['J'], special_tokens=False)
print("We have added", num_added_toks, "tokens")
model.resize_token_embeddings(len(tokenizer))

After executing this code, the new token ('J') is added as a special token, which is not the intended behavior. This differs from what I see with BERT models, where similar code adds new tokens as expected without automatically marking them as special.
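For reference, this is roughly the BERT behaviour I am comparing against (a minimal sketch; bert-base-uncased is just an example checkpoint, not necessarily the one I used):

from transformers import AutoTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
before = len(bert_tokenizer)
added = bert_tokenizer.add_tokens(["newtoken"], special_tokens=False)  # example token, purely illustrative
print(before, len(bert_tokenizer), added)        # the length grows by the number of added tokens
print(bert_tokenizer.additional_special_tokens)  # the new token is NOT listed as special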

The vocab output is below:

<bound method EsmTokenizer.get_vocab of EsmTokenizer(name_or_path='facebook/esm2_t6_8M_UR50D', vocab_size=33, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>', 'additional_special_tokens': ['J']}, clean_up_tokenization_spaces=True), added_tokens_decoder={
        0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        33: AddedToken("J", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}>

My main problem is that the length of the tokenizer does not change after adding the new token, so the code above does not extend the embedding layer as expected.
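To make the problem concrete, this is roughly what I observe after running the snippet above (illustrative; the numbers assume the 33-token base vocabulary):

print(tokenizer.vocab_size)                    # 33
print(len(tokenizer))                          # still 33, so resize_token_embeddings(len(tokenizer)) changes nothing
print(tokenizer.convert_tokens_to_ids("J"))    # 33, an index outside the current embedding matrix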

I'm seeking guidance or a workaround for this issue. Is this a known issue with the ESM2 tokenizer, or am I missing something in my implementation?

Any help or insight into this matter would be greatly appreciated.

Thank you!

Narsil (Contributor) commented Jan 8, 2024

Seems like a bug with EsmTokenizer (which doesn't use this library).

@ArthurZucker for insights, or the more relevant people?

Narsil transferred this issue from huggingface/tokenizers Jan 8, 2024
ArthurZucker (Collaborator) commented Jan 8, 2024

Hey, I cannot reproduce this:

In [23]: model_checkpoint = "facebook/esm2_t6_8M_UR50D"
    ...: tokenizer_2 = AutoTokenizer.from_pretrained(model_checkpoint)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
tokenizer_config.json: 100%|██████████| 95.0/95.0 [00:00<00:00, 135kB/s]
vocab.txt: 100%|██████████| 93.0/93.0 [00:00<00:00, 247kB/s]
special_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 416kB/s]

In [24]: tokenizer_2
Out[24]: 
EsmTokenizer(name_or_path='facebook/esm2_t6_8M_UR50D', vocab_size=33, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
        0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
>>> tokenizer_2.add_tokens(["J"]) 
EsmTokenizer(name_or_path='facebook/esm2_t6_8M_UR50D', vocab_size=33, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>', 'additional_special_tokens': ['J']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
        0: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        32: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        33: AddedToken("J", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
In [29]: tokenizer_2.get_vocab()
Out[29]: 
{'<cls>': 0,
 '<pad>': 1,
 '<eos>': 2,
 '<unk>': 3,
 'L': 4,
 'A': 5,
 'G': 6,
 'V': 7,
 'S': 8,
 'E': 9,
 'R': 10,
 'T': 11,
 'I': 12,
 'D': 13,
 'P': 14,
 'K': 15,
 'Q': 16,
 'N': 17,
 'F': 18,
 'Y': 19,
 'M': 20,
 'H': 21,
 'W': 22,
 'C': 23,
 'X': 24,
 'B': 25,
 'U': 26,
 'Z': 27,
 'O': 28,
 '.': 29,
 '-': 30,
 '<null_1>': 31,
 '<mask>': 32}

mahdip72 (Author) commented Jan 8, 2024

My main problem is that the length of the tokenizer does not change after adding the new token, so the code above does not extend the embedding layer as expected.

@ArthurZucker My problem is not that the token is marked as special. When I add new tokens, the vocab size does not change (it stays at 33). Could you help me understand how to correctly increase the embedding size of the model?

Does it make sense if I define it manually?

from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "facebook/esm2_t6_8M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
num_added_toks = tokenizer.add_tokens(['J'])
model.resize_token_embeddings(33 + num_added_toks)
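The same idea without the hard-coded 33 (just a sketch, assuming vocab_size still reports only the base vocabulary):

model.resize_token_embeddings(tokenizer.vocab_size + num_added_toks)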

ArthurZucker (Collaborator) commented:

If the token is already part of the vocab, it is expected that the vocab size will not change.

mahdip72 (Author) commented Jan 10, 2024

@ArthurZucker I am adding completely new tokens. I can see them being added to the tokenizer, but the vocab size doesn't change, even though the new indices are being set as additional_special_tokens_ids.
I bypassed the issue using the following line:

model.resize_token_embeddings(max(tokenizer.additional_special_tokens_ids) + 1)
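A slightly more general form of the same workaround (a sketch, not the exact line I ran) sizes the embedding to one past the highest token id the tokenizer knows about, whether or not the added tokens were marked as special:

# Cover the highest id among the base vocab and any added tokens; +1 because ids are 0-based.
highest_id = max(list(tokenizer.get_vocab().values()) + list(tokenizer.added_tokens_decoder.keys()))
model.resize_token_embeddings(highest_id + 1)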

ArthurZucker (Collaborator) commented Jan 11, 2024

The length of the vocab is different from the max index if you have holes in the vocab. This EsmTokenizer uses the length as the number of tokens rather than the max!
Nice fix, though I'm not sure we should change it, no?
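A tiny illustration of the difference (made-up numbers, not ESM's actual vocab):

vocab = {"<pad>": 0, "A": 1, "B": 2, "<mask>": 5}   # ids 3 and 4 are unused: a "hole"
print(len(vocab))                 # 4 -> what a length-based count reports
print(max(vocab.values()) + 1)    # 6 -> what the embedding matrix actually has to cover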

mahdip72 (Author) commented Jan 11, 2024

@ArthurZucker @Narsil I fixed my problem, but others using ESM models might still run into it. These models are very important for protein research right now. The way the tokenizer counts tokens can confuse people when they try to teach the model new tokens. This differs from the usual workflow for extending the embedding layer of models such as Llama 2 and could cause errors. Clearer steps in the documentation, or a fix in the tokenizer, would help researchers.

ArthurZucker (Collaborator) commented:

cc @Rocketknight1 we might want to update that? WDYT?
@mahdip72 would you like to open a pr for doc fixes?

Rocketknight1 (Member) commented:

Hi all, I investigated the issue. There is indeed specific code in the ESM tokenizer that causes all new added tokens to be counted as 'special' tokens. I suspect the reason for this was that the authors felt the token list for proteins was constant (since it was just the list of amino acids), and therefore any new token had to be outside the normal vocabulary.

In your case @mahdip72, I'm guessing you want to add either nonstandard amino acids or tokens like J that represent "leucine OR isoleucine", correct? This is a valid use-case for ESM, and I think we should update the tokenizer code to support it. There is the issue of backward compatibility, though, so I see two possible solutions:

1 (More backward compatible):
Update add_tokens so that it keeps special_tokens=True as the default, but lets users manually specify special_tokens=False for cases like this

2 (Matches workflows for other models):
Update add_tokens so that special_tokens=False is the default, like other models. Users will need to manually specify special_tokens=True to add tokens as special tokens. This is probably a better solution, but it may break existing workflows.

I'll see if I can grab a member of the ESM team to comment on this!
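To make the two options concrete, here is roughly what option 2 would look like from the user's side (a sketch of the intended behaviour after such a change, not the current release; <fold_token> is a made-up example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
tokenizer.add_tokens(["J"])                                   # added as a normal token by default
tokenizer.add_tokens(["<fold_token>"], special_tokens=True)   # special only when explicitly requested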

mahdip72 (Author) commented:

In your case @mahdip72, I'm guessing you want to add either nonstandard amino acids or tokens like J that represent "leucine OR isoleucine", correct?

That is correct. My goal is to add new non-separable tokens, like the existing ESM vocabulary, to the ESM tokenizer. Also, I have seen a lot of people adding non-separable 3Di Foldseek tokens and/or chemistry-related tokens such as SELFIES to protein language models. As far as I understand, these tokens are non-separable and constant, similar to the amino acid tokens.

@Rocketknight1 Are special tokens constant and inseparable? What is the difference between normal tokens and special tokens in the ESM tokenizer?

Rocketknight1 (Member) commented:

Hi @mahdip72, the idea of "special tokens" mostly comes from tokenization for language models. In general, special tokens have two main properties:

  • Special tokens can be skipped when decoding using skip_special_tokens = True.
  • Special tokens are never split by the tokenizer.

These traits aren't especially relevant for ESM - in general, people aren't generating sequences with ESM, so tokenizer decoding doesn't apply; and secondly, ESM never splits the text it tokenizes because it always converts one character to one token, unlike subword tokenizers such as SentencePiece that are commonly used for natural language.
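For example, the first property looks roughly like this with the ESM tokenizer (a sketch; the exact decoded string depends on the tokenizer version):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
ids = tokenizer("MKT")["input_ids"]                       # includes <cls> and <eos>
print(tokenizer.decode(ids))                              # special tokens show up in the output
print(tokenizer.decode(ids, skip_special_tokens=True))    # special tokens are dropped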

I think the most sensible solution is to just update add_tokens for ESM so it behaves like other models and adds tokens as "non-special" by default, even though this might affect backward compatibility slightly. What do you think?

mahdip72 (Author) commented:

@Rocketknight1 I agree. A general solution similar to other models is more sensible.

Rocketknight1 (Member) commented:

Hi @mahdip72, I've opened a PR at #28535 that should resolve this. Can you try it out and let me know if it resolves your issue? To install the PR branch, run this command: pip install --upgrade git+https://github.com/huggingface/transformers.git@allow_esm_add_tokens

tomsercu (Contributor) commented:

Thanks for looking into this Matt! Agreed, ignoring special_tokens and forcing special_tokens=True was problematic there.
It makes sense to make the behavior consistent with the rest of the Hugging Face ecosystem, although I'm a bit nervous about silently breaking backward compatibility for any users who were adding special tokens without explicitly specifying the flag special_tokens=True.

Rocketknight1 (Member) commented:

Yeah @tomsercu - it was a bit of a concern for us too! I suspect adding tokens was quite niche, though - hopefully the improvement from standardization outweighs the very small number of people who were depending on the old behaviour.

tomsercu (Contributor) commented:

Agreed. Thanks for the fix Matt!
