
Is the BOS token id of 128000 **hardcoded** into the llama 3.2 tokenizer? #33998

Closed
rasyosef opened this issue Oct 7, 2024 · 13 comments · May be fixed by #34246
Labels: bug, Core: Tokenization (Internals of the library; Tokenization)

Comments

rasyosef commented Oct 7, 2024

System Info

  • transformers version: 4.45.1
  • Platform: Linux-5.15.154+-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.23.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cpu (False)
  • Tensorflow version (GPU?): 2.15.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.8.4 (cpu)
  • Jax version: 0.4.28
  • JaxLib version: 0.4.28
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @itazap

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I trained the llama 3.2 tokenizer using an Amharic language corpus and a vocab size of 28k, but when I use it to tokenize text, the first token id is still 128000 when it should have been the new tokenizer's BOS token id of 0.
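For context, the tokenizer was trained roughly like this (a sketch; amharic_corpus stands in for the actual iterator over the Amharic training texts):

from transformers import AutoTokenizer

# Start from the Llama 3.2 tokenizer and retrain it on the Amharic corpus.
base_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# amharic_corpus is a placeholder for an iterator/list of Amharic training texts.
new_tokenizer = base_tokenizer.train_new_from_iterator(amharic_corpus, vocab_size=28000)
new_tokenizer.save_pretrained("llama-3.2-amharic-tokenizer-28k")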

And here's a tokenization of an example text. As can be seen, the first token id is 128000 when it should have been 0.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("rasyosef/llama-3.2-amharic-tokenizer-28k")

text = "ሁሉም ነገር"
inputs = tokenizer(text, return_tensors="pt")
print(inputs["input_ids"])

Output:

tensor([[128000,   1704,    802]])

Expected behavior

The first token id of the tokenized text should be the new tokenizer's BOS token id of 0 instead of the original llama 3.2 tokenizer's BOS token id of 128000. The vocab size is 28000 and the number 128000 should not appear anywhere in the input_ids list.

This is causing index out of range errors when indexing the embedding matrix of a newly initialized model.

@rasyosef added the bug label Oct 7, 2024
@ArthurZucker (Collaborator)

Hey! Your post_processor is the following:

print(tokenizer._tokenizer.post_processor)
Sequence(processors=[ByteLevel(add_prefix_space=True, trim_offsets=False, use_regex=True), TemplateProcessing(single=[SpecialToken(id="<|begin_of_text|>", type_id=0), Sequence(id=A, type_id=0)], pair=[SpecialToken(id="<|begin_of_text|>", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="<|begin_of_text|>", type_id=1), Sequence(id=B, type_id=1)], special_tokens={"<|begin_of_text|>":SpecialToken(id="<|begin_of_text|>", ids=[128000], tokens=["<|begin_of_text|>"])})])

This is because the TemplateProcessing was not changed.
This is probably an issue with train_new_from_iterator; is that what you used?

Otherwise, we can have an extra layer of logic in transformers that modifies the bos_token when loading if there is a post processor.

TL;DR, your fix:

from tokenizers.processors import ByteLevel, Sequence, TemplateProcessing

tokenizer._tokenizer.post_processor = Sequence([
    ByteLevel(add_prefix_space=True, trim_offsets=False, use_regex=True),
    TemplateProcessing(
        single="<|begin_of_text|> $0",
        pair="<|begin_of_text|> $A <|begin_of_text|> $B:1",
        special_tokens=[("<|begin_of_text|>", 0)],
    ),
])

This outputs:

tensor([[   0, 1704,  802]])

@ArthurZucker added the Core: Tokenization label Oct 7, 2024
@ArthurZucker (Collaborator)

(FYI @itazap )


rasyosef commented Oct 7, 2024

Thanks @ArthurZucker , this fixed the issue. And yes, I used train_new_from_iterator.

@rasyosef closed this as completed Oct 7, 2024
@ArthurZucker (Collaborator)

Okay, in that case I think we ought to update the post_processor when we train a new tokenizer!

@ArthurZucker (Collaborator)

Actually you can pass special tokens here:

special_tokens_map (`Dict[str, str]`, *optional*):
If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special
token name to new special token name in this argument.

which already updates them. I am working on improving the API to make it simpler!
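For example, something along these lines (a rough sketch; old_tokenizer and text_iterator are placeholders, and the new special token names are purely illustrative):

# Rough sketch: rename special tokens while retraining.
new_tokenizer = old_tokenizer.train_new_from_iterator(
    text_iterator,           # placeholder for the training corpus iterator
    vocab_size=28000,
    special_tokens_map={"<|begin_of_text|>": "<bos>", "<|end_of_text|>": "<eos>"},
)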

@rasyosef (Author)

Does this affect the postprocessor? After training the tokenizer, the special token ids are correctly updated to numbers below the vocab size (0 for bos and 1 for eos). As you pointed out earlier, it was the bos token id in the postprocessor that was not updated and remained at 128000. Here are the first few lines of tokenizer_config.json

{
  "added_tokens_decoder": {
    "0": {
      "content": "<|begin_of_text|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<|end_of_text|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<|reserved_special_token_0|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },


rasyosef commented Oct 21, 2024

Hi @ArthurZucker , is there a way to add new tokens to a BPE tokenizer like Llama 3's while keeping the original vocabulary?

I wanted to extend the vocabulary of Llama 3.2 1B's tokenizer and then continue pretraining it on a new (Amharic language) corpus. tokenizer.add_tokens(['ĠáĭĭáĬĵ', 'ĠáĬ¨áī°áĪĽ']) does not seem to work, possibly because the list of merges needs to be updated as well.

UPDATE: I got it to work! I had to edit the tokenizer.json file of the Llama 3.2 tokenizer and add the new (Amharic) tokens and the corresponding merges.
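Roughly, the edit looked like this (a sketch; the merge pair below is a placeholder, and depending on the tokenizers version merges are stored either as "left right" strings or as [left, right] pairs):

import json

with open("tokenizer.json", encoding="utf-8") as f:
    data = json.load(f)

vocab = data["model"]["vocab"]
merges = data["model"]["merges"]

# Append the new byte-level tokens with fresh ids at the end of the vocab.
next_id = max(vocab.values()) + 1
for token in ["ĠáĭĭáĬĵ", "ĠáĬ¨áī°áĪĽ"]:
    if token not in vocab:
        vocab[token] = next_id
        next_id += 1

# Every new token also needs merge rule(s) that build it out of existing pieces.
# The pair below is a placeholder; the real ones depend on the byte-level sub-tokens.
merges.append("Ġáĭ ĭáĬĵ")

with open("tokenizer.json", "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False)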

@ArthurZucker (Collaborator)

Hey! You can quite simply add the tokens, AS LONG AS you use the byte-level normalizer and not the `pre_tokenizer`!
It was added in: huggingface/tokenizers#1555
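Something along these lines (a sketch, assuming a recent tokenizers release and a tokenizer configured with the byte-level normalizer; the two strings are roughly the plain-text forms of the byte-level tokens above, and resizing the embeddings is only needed if you continue pretraining):

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# Add the tokens as plain text; the byte-level normalizer handles the byte mapping.
num_added = tokenizer.add_tokens(["ዋና", "ከተማ"])

# Grow the embedding matrix so the new token ids are valid.
model.resize_token_embeddings(len(tokenizer))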

@rasyosef (Author)

Cool, Thanks!

@csHuangfdu

If I want to output tensor([[1704, 802]]), that is, without the bos_token, what should I do?

@ArthurZucker (Collaborator)

You can set add_bos_token=False and from_slow=False

@rohithkrn

@ArthurZucker setting add_bos_token=False and from_slow=False still outputs bos_token

@ArthurZucker (Collaborator)

Ah right, this is because we use the PreTrainedTokenizerFast API. Let me help:

from tokenizers.processors import ByteLevel

tokenizer._tokenizer.post_processor = ByteLevel()
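With that post-processor, re-running the snippet from the issue should give the ids without any BOS prepended:

inputs = tokenizer("ሁሉም ነገር", return_tensors="pt")
print(inputs["input_ids"])  # expected: tensor([[1704,  802]]), no BOS id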
