Weird Tokenization when Training New Tokenizer from Llama 2 Tokenizer using train_new_from_iterator
#27900
Comments
Ahhh, I'll have a look. That looks a bit nasty indeed.
Hi @ArthurZucker, any updates on this? Thank you!
Hey, I can't reproduce this yet. I don't have your local dataset, and I don't have the loading script, so

```python
from datasets import load_dataset

def python_generator():
    # Load local files for code_search_net/python
    # https://huggingface.co/datasets/code_search_net
    dataset = load_dataset("code_search_net/python.py", "python")
    dataset = dataset["train"]
    for start_idx in range(0, len(dataset), 1000):
        samples = dataset[start_idx: start_idx + 1000]
        yield samples["whole_func_string"]
```

fails with:
I cannot help you without a proper reproducer.
One thing that is certain is that byte fallback does not seem to be activated properly, because the byte tokens should be part of the vocab. The trainer should have logic to handle that, which it does not at the moment.
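For what it's worth, a quick way to check this is to count how many of the `<0x..>` byte tokens are present in a tokenizer's vocab. A minimal sketch (run it on the retrained tokenizer; the original Llama-2 tokenizer is loaded here only so the snippet is self-contained):

```python
from transformers import AutoTokenizer

# Count how many of the 256 byte-fallback tokens (<0x00> ... <0xFF>) are in the vocab.
# The original Llama-2 tokenizer has all of them; a tokenizer retrained with
# train_new_from_iterator can be checked the same way by swapping it in here.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
byte_tokens = [f"<0x{i:02X}>" for i in range(256)]
vocab = tokenizer.get_vocab()
present = sum(token in vocab for token in byte_tokens)
print(f"{present}/256 byte-fallback tokens are in the vocab")
```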
I've updated the script above. Hopefully it works now!
Same here! There are tokens in the vocabulary that consist of some joined words, like:
What did you train your tokenizer on?
@phoongkhangzhie I had to update your script; it does not work out of the box.
@ArthurZucker on batches of strings. It seems it's not splitting words.
I think a quick fix would be to disable the normalizer and use a metaspace pre-tokenizer instead:

```python
from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme="first")
```
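In case it is useful, here is a minimal end-to-end sketch of retraining with that patched pre-tokenizer (the tiny inline corpus and the vocab size of 32000 are placeholders, not part of the original report; the Metaspace call is copied from the comment above):

```python
from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Disable the normalizer and move the metaspace handling to the pre-tokenizer.
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme="first")

# Placeholder corpus; replace with the code_search_net generator shared above.
corpus = ["def add(a, b):\n    return a + b", "print('hello world')"]
new_tokenizer = old_tokenizer.train_new_from_iterator(corpus, vocab_size=32000)
print(new_tokenizer.tokenize("This is a test."))
```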
It works, the vocabulary is correctly generated now. However, it does not pre-tokenize punctuation:

```python
(Pdb) old_tokenizer.convert_ids_to_tokens(old_tokenizer("This is a test.")["input_ids"])
['<s>', '▁This', '▁is', '▁a', '▁test', '.']
(Pdb) new_tokenizer.convert_ids_to_tokens(new_tokenizer("This is a test.")["input_ids"])
['<s>', '▁Th', 'is', '▁is', '▁a', '▁tes', 't.']
```
That's because it is probably missing a Replace normalizer, so something like this:

```python
from tokenizers import pre_tokenizers, normalizers, Regex
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme="first")
```

(make sure you don't use "_" but "▁")
I've added the normalizer as you said. It solves the final dot issue. However, inner punctuation is still not tokenized, and there are tokens like that in the vocabulary. This is the setup I'm using now:

```python
import tokenizers
from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)  # base_model is my local model path
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(tokenizers.Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Punctuation(), pre_tokenizers.Metaspace(prepend_scheme="first")])
```
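A quick way to surface those entries is to scan the vocab for tokens that mix letters with punctuation. A sketch (run it on the retrained tokenizer; the original Llama-2 tokenizer is loaded here only so the snippet runs on its own):

```python
import re
from transformers import AutoTokenizer

# List vocab entries that glue letters and punctuation together, one symptom of
# the pre-tokenization not splitting the way we expect. Swap in the retrained
# tokenizer to inspect its vocabulary instead.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
suspicious = [
    token for token in tokenizer.get_vocab()
    if re.search(r"[A-Za-z]", token) and re.search(r"[.,;:!?]", token)
]
print(f"{len(suspicious)} mixed letter/punctuation tokens, e.g. {suspicious[:10]}")
```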
Thank you @ArthurZucker and @anderleich for your inputs. I think there are still issues with the tokenizer even after the various fixes.
With the above fix, the outputs are:
This fix prepends all whitespace characters with ▁.
With the above fix, the outputs are:
This fix collapses all the whitespace characters into a single ▁.
And with this fix, the outputs are:
While this tokenization might be better than the above one, I think it is too aggressive with the splitting of the punctuation. Like the above fixes, the newline character is still not handled properly. Ideally, the outputs should be like this (similar to the GPT-2 tokenization):
Will there be any other fixes for this?
If you want to keep the white space, …
Regarding the merges, it might be the frequency of the …
So the last issue is probably the bytefallback.
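Until the trainer handles byte fallback itself, one stopgap (my own workaround, not built-in support) is to pass the 256 byte tokens explicitly through `new_special_tokens` so they at least end up in the new vocab; the BPE model still will not actually fall back to bytes for unseen characters:

```python
from transformers import AutoTokenizer

# Workaround sketch: force the <0x00> ... <0xFF> tokens into the retrained vocab.
# This only guarantees their presence; it does not re-enable real byte fallback.
old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
byte_tokens = [f"<0x{i:02X}>" for i in range(256)]
corpus = ["def add(a, b):\n    return a + b"]  # placeholder corpus
new_tokenizer = old_tokenizer.train_new_from_iterator(
    corpus, vocab_size=32000, new_special_tokens=byte_tokens
)
print(len(new_tokenizer))
```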
@ArthurZucker are there any plans to add all those fixes to the `LlamaTokenizer`?
There are plans to add these fixes to the LlamaTokenizer as a whole (specifically the pretokenizer vs normalizer) here: #26678. The bytefallback thing needs to be added to the trainer.
System Info
transformers version: 4.35.2
Who can help?
@ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
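The original reproduction snippet did not survive the copy here; reconstructed from the comments above, the setup was roughly the following (the dataset loading call, vocab size, and sample string are assumptions):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Retrain a tokenizer on the code_search_net Python split and compare the result
# for gpt2 vs. Llama-2, which is where the odd "▁" placement shows up.
dataset = load_dataset("code_search_net", "python", split="train")

def python_generator(batch_size=1000):
    for start_idx in range(0, len(dataset), batch_size):
        yield dataset[start_idx : start_idx + batch_size]["whole_func_string"]

example = "def add_numbers(a, b):\n    return a + b"
for name in ("gpt2", "meta-llama/Llama-2-7b-hf"):
    old_tokenizer = AutoTokenizer.from_pretrained(name)
    new_tokenizer = old_tokenizer.train_new_from_iterator(python_generator(), vocab_size=32000)
    print(name, new_tokenizer.tokenize(example))
```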
Expected behavior
The function `train_new_from_iterator` works as expected when training a new tokenizer from a gpt2 tokenizer, as demonstrated in the example, but does not work for training a new tokenizer from a Llama-2 tokenizer. With the code snippet above, training a tokenizer from gpt2 gives the output:
However, training Llama-2's tokenizer gives:
The underscores `_` should be prepended at the front of new words, but they seem to be inserted at the back of words or in between words. In fact, it seems like the retrained tokenizer is worse than the original tokenizer on the new data.