Is the BOS token id of 128000 **hardcoded** into the llama 3.2 tokenizer? #33998
Comments
Hey! Your print(tokenizer._tokenizer.post_processor) gives:

Sequence(processors=[ByteLevel(add_prefix_space=True, trim_offsets=False, use_regex=True), TemplateProcessing(single=[SpecialToken(id="<|begin_of_text|>", type_id=0), Sequence(id=A, type_id=0)], pair=[SpecialToken(id="<|begin_of_text|>", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="<|begin_of_text|>", type_id=1), Sequence(id=B, type_id=1)], special_tokens={"<|begin_of_text|>":SpecialToken(id="<|begin_of_text|>", ids=[128000], tokens=["<|begin_of_text|>"])})])

This is because the TemplateProcessing post_processor still has the original id 128000 baked into its special_tokens. Otherwise, we can have an extra layer of logic in train_new_from_iterator. TLDR; your fix:

from tokenizers.processors import ByteLevel, Sequence, TemplateProcessing

tokenizer._tokenizer.post_processor = Sequence(
    [
        ByteLevel(add_prefix_space=True, trim_offsets=False, use_regex=True),
        TemplateProcessing(
            single="<|begin_of_text|> $0",
            pair="<|begin_of_text|> $A <|begin_of_text|> $B:1",
            special_tokens=[("<|begin_of_text|>", 0)],
        ),
    ]
)

This outputs: tensor([[ 0, 1704, 802]])
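As a quick sanity check of the fix above (a minimal sketch; the text is a placeholder standing in for the Amharic example from this thread):

# After replacing the post_processor, the first id should be the new BOS id 0
# rather than the original Llama 3.2 id 128000.
ids = tokenizer("placeholder Amharic text", return_tensors="pt").input_ids
print(ids)                     # e.g. tensor([[  0, 1704,  802]]) per the comment above
assert ids[0, 0].item() == 0   # new BOS id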
(FYI @itazap)
Thanks @ArthurZucker, this fixed the issue. And yes, I used train_new_from_iterator.
Okay, in that case I think we ought to update the post_processor when we train from new!
Actually you can pass special tokens here: transformers/src/transformers/tokenization_utils_fast.py, lines 739 to 741 (at c185608), which already updates them. I am working on improving the API to make it simpler!
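A hedged sketch of that path, assuming a retraining setup like the one in this issue (the corpus iterator and output directory name are placeholders):

from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# train_new_from_iterator exposes new_special_tokens / special_tokens_map,
# which is where the special tokens mentioned above can be passed.
new_tokenizer = old_tokenizer.train_new_from_iterator(
    amharic_corpus_iterator,          # placeholder: any iterator yielding strings
    vocab_size=28000,
    new_special_tokens=["<|begin_of_text|>", "<|end_of_text|>"],
)
new_tokenizer.save_pretrained("llama-3.2-amharic-tokenizer")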
Does this affect the post_processor? After training the tokenizer, the special token ids are correctly updated to numbers below the vocab size (added_tokens_decoder from the new tokenizer's config below; a quick way to check the post_processor is sketched after it): {
"added_tokens_decoder": {
"0": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<|end_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<|reserved_special_token_0|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
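A small sketch of that check, assuming new_tokenizer is the retrained tokenizer from above:

# The added_tokens_decoder above maps <|begin_of_text|> to id 0, so the
# post_processor should reference 0 as well; printing it shows whether the
# original id 128000 is still baked into TemplateProcessing.
print(new_tokenizer.convert_tokens_to_ids("<|begin_of_text|>"))  # expected: 0
print(new_tokenizer._tokenizer.post_processor)                   # inspect the ids=[...] field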
Hi @ArthurZucker, is there a way to add new tokens to a BPE tokenizer like Llama 3's while keeping the original vocabulary? I wanted to extend the vocabulary of Llama 3.2 1B's tokenizer and then continue pretraining it on a new (Amharic language) corpus. UPDATE: I got it to work. I had to edit the …
Hey! You can quite simply add the tokens, AS LONG AS you use the …
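For reference, one commonly used route for extending the vocabulary while keeping the original one intact is add_tokens plus resizing the embeddings. This is only a hedged sketch (the token list is a made-up example), not necessarily the exact approach meant in the truncated answer above:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Hypothetical list of frequent Amharic subwords to append to the existing vocab.
new_tokens = ["ሰላም", "ኢትዮጵያ"]
num_added = tokenizer.add_tokens(new_tokens)

# Grow the embedding matrix so the new ids are valid before continued pretraining.
model.resize_token_embeddings(len(tokenizer))
print(f"added {num_added} tokens, new vocab size: {len(tokenizer)}")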
Cool, thanks!
If I want the output to be tensor([[1704, 802]]) (i.e. without the BOS id), what should I do?
You can set …
@ArthurZucker setting … did not work; the BOS id is still added.
Ah right, this is because we use the TemplateProcessing post_processor. You can just do:

from tokenizers.processors import *

tokenizer._tokenizer.post_processor = ByteLevel()
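A quick check of that workaround (placeholder text; the expected shape of the result follows the earlier comment):

# With the TemplateProcessing step dropped, no BOS id is prepended anymore.
ids = tokenizer("placeholder Amharic text", return_tensors="pt").input_ids
print(ids)  # e.g. tensor([[1704,  802]]) instead of tensor([[  0, 1704,  802]])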
System Info

transformers version: 4.45.1

Who can help?

@ArthurZucker @itazap

Reproduction
I trained the llama 3.2 tokenizer using an Amharic language corpus and a vocab size of 28k, but when I use it to tokenize text, the first token id is still 128000 when it should have been the new tokenizer's BOS token id of 0.

And here's a tokenization of an example text. As can be seen, the first token id is 128000 when it should have been 0.

Output:
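A minimal sketch of the reproduction as described above (model id, corpus iterator, and example sentence are placeholders, not the exact snippet from the report):

from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
new_tokenizer = old_tokenizer.train_new_from_iterator(
    amharic_corpus_iterator,  # placeholder: iterator over the Amharic corpus
    vocab_size=28000,
)

ids = new_tokenizer("placeholder Amharic sentence", return_tensors="pt").input_ids
print(ids)  # first id comes out as 128000 instead of the new BOS id 0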
Expected behavior

The first token id of the tokenized text should be the new tokenizer's BOS token id of 0 instead of the original llama 3.2 tokenizer's BOS token id of 128000. The vocab size is 28000 and the number 128000 should not appear anywhere in the input_ids list. This is causing index out of range errors when indexing the embedding matrix of a newly initialized model.
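To illustrate why this blows up, a hedged sketch with a tiny randomly initialized Llama config (the sizes are arbitrary, chosen only to keep the example small):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(vocab_size=28000, hidden_size=256, intermediate_size=512,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(config)

good_ids = torch.tensor([[0, 1704, 802]])       # all ids < 28000: fine
bad_ids = torch.tensor([[128000, 1704, 802]])   # 128000 >= vocab_size: indexes past the embedding matrix
model(input_ids=good_ids)                       # works
# model(input_ids=bad_ids)                      # raises the "index out of range" error described above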