
Weird Tokenization when Training New Tokenizer from Llama 2 Tokenizer using train_new_from_iterator #27900

Closed
phoongkhangzhie opened this issue Dec 8, 2023 · 19 comments · Fixed by #26678

@phoongkhangzhie commented Dec 8, 2023

System Info

  • transformers version: 4.35.2
  • Platform: Linux-5.4.0-105-generic-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.1
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): not installed (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import os
from datasets import load_dataset
from transformers import AutoTokenizer

def python_generator():
    # Load local files for code_search_net/python
    # https://huggingface.co/datasets/code_search_net
    dataset = load_dataset("code_search_net", "python")
    dataset = dataset["train"]
    for start_idx in range(0, len(dataset), 1000):
        samples = dataset[start_idx: start_idx + 1000]
        yield samples["whole_func_string"]

def main():
    model_paths = [
        "gpt2",
        "meta-llama/Llama-2-70b-hf",
    ]
    access_token = ""
    for model_path in model_paths:
        print(f"\n\n{model_path}")
        save_dir = (
            f"{model_path}-python-52K_vocab"
        )
        os.makedirs(os.path.join(os.getcwd(), "tokenizers"), exist_ok=True)
        save_path = os.path.join(os.getcwd(), "tokenizers", save_dir)

        old_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            token=access_token
        )
        assert old_tokenizer.is_fast

        if os.path.exists(save_path):
            new_tokenizer = AutoTokenizer.from_pretrained(save_path)
        else:
            new_tokenizer = old_tokenizer.train_new_from_iterator(
                python_generator(),
                vocab_size=52000
            )
            new_tokenizer.save_pretrained(save_path)

        example_1 = '''
        def add_numbers(a, b):
            """Add the two numbers `a` and `b`."""
            return a + b
        '''
        print(f"\n{example_1}")
        old_tokens = old_tokenizer.tokenize(example_1)
        print(f"old: {old_tokens}")
        new_tokens = new_tokenizer.tokenize(example_1)
        print(f"new: {new_tokens}")

        example_2 = """
        class LinearLayer():
            def __init__(self, input_size, output_size):
                self.weight = torch.randn(input_size, output_size)
                self.bias = torch.zeros(output_size)

            def __call__(self, x):
                return x @ self.weights + self.bias
        """
        print(f"\n{example_2}")
        old_tokens = old_tokenizer.tokenize(example_2)
        print(f"old: {old_tokens}")
        new_tokens = new_tokenizer.tokenize(example_2)
        print(f"new: {new_tokens}")

if __name__ == "__main__":
    main()

Expected behavior

The function train_new_from_iterator works as expected when training a new tokenizer from a gpt2 tokenizer as demonstrated in the example, but does not work for training a new tokenizer from a Llama-2 tokenizer.

With the code snippet above, training a tokenizer from gpt2 gives the output:

Example 1:
def add_numbers(a, b):
  """Add the two numbers `a` and `b`."""
  return a + b
        
old: ['Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġdef', 'Ġadd', '_', 'n', 'umbers', '(', 'a', ',', 'Ġb', '):', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ"""', 'Add', 'Ġthe', 'Ġtwo', 'Ġnumbers', 'Ġ`', 'a', '`', 'Ġand', 'Ġ`', 'b', '`', '."', '""', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġreturn', 'Ġa', 'Ġ+', 'Ġb', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ']
new: ['ĊĠĠĠĠĠĠĠ', 'Ġdef', 'Ġadd', '_', 'numbers', '(', 'a', ',', 'Ġb', '):', 'ĊĠĠĠĠĠĠĠĠĠĠĠ', 'Ġ"""', 'Add', 'Ġthe', 'Ġtwo', 'Ġnumbers', 'Ġ`', 'a', '`', 'Ġand', 'Ġ`', 'b', '`."""', 'ĊĠĠĠĠĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġa', 'Ġ+', 'Ġb', 'ĊĠĠĠĠĠĠĠĠ']

Example 2:
class LinearLayer():
  def __init__(self, input_size, output_size):
    self.weight = torch.randn(input_size, output_size)
    self.bias = torch.zeros(output_size)

  def __call__(self, x):
    return x @ self.weights + self.bias
        
old: ['Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġclass', 'ĠLinear', 'Layer', '():', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġdef', 'Ġ__', 'init', '__', '(', 'self', ',', 'Ġinput', '_', 'size', ',', 'Ġoutput', '_', 'size', '):', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġself', '.', 'weight', 'Ġ=', 'Ġtorch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', 'Ġoutput', '_', 'size', ')', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġself', '.', 'b', 'ias', 'Ġ=', 'Ġtorch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', 'ĊĊ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġdef', 'Ġ__', 'call', '__', '(', 'self', ',', 'Ġx', '):', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġreturn', 'Ġx', 'Ġ@', 'Ġself', '.', 'weights', 'Ġ+', 'Ġself', '.', 'b', 'ias', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ', 'Ġ']
new: ['ĊĠĠĠĠĠĠĠ', 'Ġclass', 'ĠLinear', 'Layer', '():', 'ĊĠĠĠĠĠĠĠĠĠĠĠ', 'Ġdef', 'Ġ__', 'init', '__(', 'self', ',', 'Ġinput', '_', 'size', ',', 'Ġoutput', '_', 'size', '):', 'ĊĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ', 'Ġself', '.', 'weight', 'Ġ=', 'Ġtorch', '.', 'randn', '(', 'input', '_', 'size', ',', 'Ġoutput', '_', 'size', ')', 'ĊĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ', 'Ġself', '.', 'bias', 'Ġ=', 'Ġtorch', '.', 'zeros', '(', 'output', '_', 'size', ')', 'ĊĊĠĠĠĠĠĠĠĠĠĠĠ', 'Ġdef', 'Ġ__', 'call', '__(', 'self', ',', 'Ġx', '):', 'ĊĠĠĠĠĠĠĠĠĠĠĠĠĠĠĠ', 'Ġreturn', 'Ġx', 'Ġ@', 'Ġself', '.', 'weights', 'Ġ+', 'Ġself', '.', 'bias', 'ĊĠĠĠĠĠĠĠĠ']

However, training Llama-2's tokenizer gives:

Example 1:
def add_numbers(a, b):
  """Add the two numbers `a` and `b`."""
  return a + b
        
old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁`', 'a', '`', '▁and', '▁`', 'b', '`', '."', '""', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁a', '▁+', '▁b', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁', '\n▁▁▁▁▁▁▁▁def▁', 'add_', 'number', 's(', 'a,▁b', '):\n▁▁▁▁▁▁▁▁▁▁▁▁"""', 'Add▁the▁', 'two▁', 'number', 's▁`', 'a', '`▁and▁`', 'b', '`', '."""', '\n▁▁▁▁▁▁▁▁▁▁▁▁return▁', 'a▁+▁', 'b', '\n▁▁▁▁▁▁▁▁']

Example 2:
class LinearLayer():
  def __init__(self, input_size, output_size):
    self.weight = torch.randn(input_size, output_size)
    self.bias = torch.zeros(output_size)

  def __call__(self, x):
    return x @ self.weights + self.bias
        
old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁class', '▁Linear', 'Layer', '():', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'weight', '▁=', '▁tor', 'ch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'b', 'ias', '▁=', '▁tor', 'ch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', '<0x0A>', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁x', '▁@', '▁self', '.', 'we', 'ights', '▁+', '▁self', '.', 'b', 'ias', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁', '\n▁▁▁▁▁▁▁▁', 'class▁', 'Linear', 'Layer(', '):\n▁▁▁▁▁▁▁▁▁▁▁▁', 'def▁__init__(self,▁', 'input_', 'size,▁', 'output_', 'size', '):\n▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁self.', 'weight▁=▁', 'torch', '.r', 'and', 'n(', 'input_', 'size,▁', 'output_', 'size', ')\n▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁self.', 'bi', 'as▁=▁', 'torch.', 'zeros(', 'output_', 'size', ')\n\n▁▁▁▁▁▁▁▁▁▁▁▁', 'def▁__', 'call__', '(self,▁x', '):\n▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁return▁', 'x▁', '@▁', 'self.', 'weight', 's▁+▁', 'self.', 'bias', '\n▁▁▁▁▁▁▁▁']

The metaspace character ▁ should be prepended to the start of new words, but instead it appears at the end of words or in the middle of merged words. In fact, the retrained tokenizer seems worse than the original tokenizer on the new data.

@ArthurZucker (Collaborator)

Ahhh, I'll have a look. That looks a bit nasty indeed.

@larrylawl

Hi @ArthurZucker , any updates on this? Thank you!

@ArthurZucker (Collaborator)

Hey, I can't reproduce this yet. I don't have your local dataset or the loading script, so

def python_generator():
    # Load local files for code_search_net/python
    # https://huggingface.co/datasets/code_search_net
    dataset = load_dataset("code_search_net/python.py", "python")
    dataset = dataset["train"]
    for start_idx in range(0, len(dataset), 1000):
        samples = dataset[start_idx: start_idx + 1000]
        yield samples["whole_func_string"]

fails with
FileNotFoundError: Couldn't find a dataset script at /Users/arthurzucker/Work/transformers/deci-7b/code_search_net/python.py

@ArthurZucker (Collaborator)

I cannot help you without a proper reproducer

@ArthurZucker (Collaborator)

One thing that is certain is that byte fallback does not seem to be activated properly: the byte tokens should be part of the vocab, and the trainer should have logic to handle that, which it does not at the moment.
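
A quick way to check this on the retrained tokenizer (just a sketch; it assumes the byte-fallback pieces use the <0xNN> naming from Llama-2's original vocab):

# Sketch: check whether the 256 byte-fallback pieces survived retraining.
# Assumes the "<0xNN>" naming used by Llama-2's original vocab.
byte_tokens = [f"<0x{i:02X}>" for i in range(256)]
vocab = new_tokenizer.get_vocab()
missing = [t for t in byte_tokens if t not in vocab]
print(f"{len(missing)} of 256 byte-fallback tokens missing from the retrained vocab")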

@phoongkhangzhie (Author)

I cannot help you without a proper reproducer

I've updated the script above. Hopefully it works now!

@anderleich

Same here! There are tokens in the vocabulary that consist of some joined words, like this▁is▁a▁test
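
A quick way to spot these in the retrained tokenizer (a sketch, using the new_tokenizer variable from the script above):

# Sketch: list learned tokens where "▁" appears after the first character,
# i.e. several words merged into a single piece.
merged = [tok for tok in new_tokenizer.get_vocab() if "▁" in tok[1:]]
print(len(merged), merged[:20])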

@ArthurZucker (Collaborator)

What did you train your tokenizer on?

@ArthurZucker (Collaborator)

@phoongkhangzhie I had to update your script; it does not work out of the box.

@anderleich

@ArthurZucker on batches of strings. It seems it's not splitting words.

@ArthurZucker (Collaborator) commented Dec 19, 2023

I think a quick fix would be to disable the normalizer and use a metaspace pre-tokenizer instead.

from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme = "first")

@anderleich

It works, the vocabulary is correctly generated now. However, it does not pretokenize punctuation:

(Pdb) old_tokenizer.convert_ids_to_tokens(old_tokenizer("This is a test.")["input_ids"])
['<s>', '▁This', '▁is', '▁a', '▁test', '.']
(Pdb) new_tokenizer.convert_ids_to_tokens(new_tokenizer("This is a test.")["input_ids"])
['<s>', '▁Th', 'is', '▁is', '▁a', '▁tes', 't.']

@ArthurZucker (Collaborator)

That's because it is probably missing a Replace normalizer, so something like this:

from tokenizers import Regex, pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme = "first")

(Make sure you use "▁", not "_".)

@ArthurZucker (Collaborator)

#26678 should provide the fix.
cc @xenova as this seems to give us a headache hahaa

@anderleich

I've added the normalizer as you said. It solves the final dot issue. However, inner punctuation is not split; there are tokens like ▁(house) in the final vocabulary. I think we need to add pre_tokenizers.Punctuation() to the pre-tokenizer sequence:

import tokenizers
from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(tokenizers.Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Punctuation(), pre_tokenizers.Metaspace(prepend_scheme="first")])
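
To see what the Punctuation pre-tokenizer does on its own (just an illustration; note that it also treats _ as punctuation, which matters for dunder names):

from tokenizers import pre_tokenizers
# Illustration: Punctuation isolates every punctuation character, including "_".
print(pre_tokenizers.Punctuation().pre_tokenize_str("def __init__(self, x):"))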

@phoongkhangzhie (Author)

Thank you @ArthurZucker and @anderleich for your inputs.

I think there are still issues with the tokenizer even after the various fixes.

I think a quick fix would be to disable the normalizer and use a metaspace pre-tokenizer instead.

from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme = "first")

With the above fix, the outputs are:

Example 1:
def add_numbers(a, b):
    """Add the two numbers `a` and `b`."""
    return a + b

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁`', 'a', '`', '▁and', '▁`', 'b', '`', '."', '""', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁a', '▁+', '▁b', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁def', '▁add_', 'number', 's(', 'a,', '▁b):\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁"""Add', '▁the', '▁two', '▁numbers', '▁`a`', '▁and', '▁`b', '`."""\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁return', '▁a', '▁+', '▁b\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁']

Example 2:
class LinearLayer():
    def __init__(self, input_size, output_size):
        self.weight = torch.randn(input_size, output_size)
        self.bias = torch.zeros(output_size)

    def __call__(self, x):
        return x @ self.weights + self.bias

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁class', '▁Linear', 'Layer', '():', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'weight', '▁=', '▁tor', 'ch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'b', 'ias', '▁=', '▁tor', 'ch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', '<0x0A>', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁x', '▁@', '▁self', '.', 'we', 'ights', '▁+', '▁self', '.', 'b', 'ias', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁class', '▁Linear', 'Layer', '():\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁def', '▁__init__(self,', '▁input', '_size,', '▁output', '_size):\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁self.weight', '▁=', '▁torch', '.randn', '(input', '_size,', '▁output', '_size)\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁self.b', 'ias', '▁=', '▁torch.', 'zeros(', 'output', '_size)\n\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁def', '▁__call', '__(self,', '▁x):\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁return', '▁x', '▁@', '▁self.', 'weights', '▁+', '▁self.b', 'ias', '\n', '▁', '▁', '▁', '▁', '▁', '▁', '▁', '▁']

This fix prepends '▁' to every whitespace character, but each one remains a separate token in the final output, whereas some of them should be merged to represent single or double indentation in code. Also, the newline character \n is not treated as a whitespace character.

That's because it is probably missing a Replace normalizer, so something like this:

from tokenizers import Regex, pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace("▁", True, prepend_scheme = "first")

(Make sure you use "▁", not "_".)

With the above fix, the outputs are:

Example 1:
def add_numbers(a, b):
    """Add the two numbers `a` and `b`."""
    return a + b

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁`', 'a', '`', '▁and', '▁`', 'b', '`', '."', '""', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁a', '▁+', '▁b', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁def', '▁add_', 'number', 's(', 'a,', '▁b):\n', '▁"""Add', '▁the', '▁two', '▁numbers', '▁`a`', '▁and', '▁`b', '`."""\n', '▁return', '▁a', '▁+', '▁b']

Example 2:
class LinearLayer():
    def __init__(self, input_size, output_size):
        self.weight = torch.randn(input_size, output_size)
        self.bias = torch.zeros(output_size)

    def __call__(self, x):
        return x @ self.weights + self.bias

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁class', '▁Linear', 'Layer', '():', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'weight', '▁=', '▁tor', 'ch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'b', 'ias', '▁=', '▁tor', 'ch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', '<0x0A>', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁x', '▁@', '▁self', '.', 'we', 'ights', '▁+', '▁self', '.', 'b', 'ias', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁class', '▁Linear', 'Layer', '():\n', '▁def', '▁__init__(self,', '▁input', '_size,', '▁output', '_size):\n', '▁self.weight', '▁=', '▁torch', '.randn', '(input', '_size,', '▁output', '_size)\n', '▁self.b', 'ias', '▁=', '▁torch.', 'zeros(', 'output', '_size)\n\n', '▁def', '▁__call', '__(self,', '▁x):\n', '▁return', '▁x', '▁@', '▁self.', 'weights', '▁+', '▁self.b', 'ias']

This fix collapses all whitespace into a single '▁' character. However, this discards information that matters in code, such as the different indentation levels. Again, the newline character \n is not treated as a whitespace character.

I've added the normalizer as you said. It solves the final dot issue. However, inner punctuation is not split; there are tokens like ▁(house) in the final vocabulary. I think we need to add pre_tokenizers.Punctuation() to the pre-tokenizer sequence:

import tokenizers
from tokenizers import pre_tokenizers, normalizers
from transformers import AutoTokenizer
old_tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
old_tokenizer._tokenizer.normalizer = normalizers.Sequence([normalizers.Strip(left=False, right=True), normalizers.Replace(tokenizers.Regex(" {2,}"), "▁")])
old_tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Punctuation(), pre_tokenizers.Metaspace(prepend_scheme="first")])

And with this fix, the outputs are:

Example 1:
def add_numbers(a, b):
    """Add the two numbers `a` and `b`."""
    return a + b

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁`', 'a', '`', '▁and', '▁`', 'b', '`', '."', '""', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁a', '▁+', '▁b', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', ')', ':', '\n', '▁', '"', '"', '"', 'Add', '▁the', '▁two', '▁numbers', '▁', '`', 'a', '`', '▁and', '▁', '`', 'b', '`', '.', '"', '"', '"', '\n', '▁return', '▁a', '▁', '+', '▁b']

Example 2:
class LinearLayer():
    def __init__(self, input_size, output_size):
        self.weight = torch.randn(input_size, output_size)
        self.bias = torch.zeros(output_size)

    def __call__(self, x):
        return x @ self.weights + self.bias

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁class', '▁Linear', 'Layer', '():', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'weight', '▁=', '▁tor', 'ch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'b', 'ias', '▁=', '▁tor', 'ch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', '<0x0A>', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁x', '▁@', '▁self', '.', 'we', 'ights', '▁+', '▁self', '.', 'b', 'ias', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁class', '▁Linear', 'Layer', '(', ')', ':', '\n', '▁def', '▁', '_', '_', 'init', '_', '_', '(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', ')', ':', '\n', '▁self', '.', 'weight', '▁', '=', '▁torch', '.', 'randn', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '\n', '▁self', '.', 'bias', '▁', '=', '▁torch', '.', 'zeros', '(', 'output', '_', 'size', ')', '\n\n', '▁def', '▁', '_', '_', 'call', '_', '_', '(', 'self', ',', '▁x', ')', ':', '\n', '▁return', '▁x', '▁', '@', '▁self', '.', 'weights', '▁', '+', '▁self', '.', 'bias']

While this tokenization might be better than the one above, I think it is too aggressive in splitting punctuation. Like the above fixes, the newline character \n is not treated as a whitespace character.
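
For comparison, GPT-2's byte-level pre-tokenizer keeps runs of whitespace together as single pieces (a quick illustration; the exact split depends on its default regex):

from tokenizers import pre_tokenizers
# Illustration: ByteLevel with its default GPT-2 regex keeps indentation runs
# together instead of emitting one piece per space.
print(pre_tokenizers.ByteLevel(add_prefix_space=False).pre_tokenize_str("\n        return a + b"))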

Ideally, the outputs should be like this (similar to the GPT2 tokenization):

Example 1:
def add_numbers(a, b):
    """Add the two numbers `a` and `b`."""
    return a + b

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁`', 'a', '`', '▁and', '▁`', 'b', '`', '."', '""', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁a', '▁+', '▁b', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁def', '▁add', '_', 'numbers', '(', 'a', ',', '▁b', ')', ':', '▁\n', '▁"""', 'Add', '▁the', '▁two', '▁numbers', '▁', '`', 'a', '`', '▁and', '▁', '`', 'b', '`', '."""', '▁\n', '▁return', '▁a', '▁+', '▁b']

Example 2:
class LinearLayer():
    def __init__(self, input_size, output_size):
        self.weight = torch.randn(input_size, output_size)
        self.bias = torch.zeros(output_size)

    def __call__(self, x):
        return x @ self.weights + self.bias

old: ['▁', '<0x0A>', '▁▁▁▁▁▁▁', '▁class', '▁Linear', 'Layer', '():', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'weight', '▁=', '▁tor', 'ch', '.', 'rand', 'n', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁self', '.', 'b', 'ias', '▁=', '▁tor', 'ch', '.', 'zer', 'os', '(', 'output', '_', 'size', ')', '<0x0A>', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁return', '▁x', '▁@', '▁self', '.', 'we', 'ights', '▁+', '▁self', '.', 'b', 'ias', '<0x0A>', '▁▁▁▁▁▁▁▁']
new: ['▁\n', '▁class', '▁Linear', 'Layer', '():', '▁\n', '▁def', '▁__', 'init', '__(', 'self', ',', '▁input', '_', 'size', ',', '▁output', '_', 'size', '):', '▁\n', '▁self', '.', 'weight', '▁=', '▁torch', '.', 'randn', '(', 'input', '_', 'size', ',', '▁output', '_', 'size', ')', '▁\n', '▁self', '.', 'bias', '▁=', '▁torch', '.', 'zeros', '(', 'output', '_', 'size', ')', '▁\n\n', '▁def', '▁__', 'call', '__(', 'self', ',', '▁x', '):', '▁\n', '▁return', '▁x', '▁@', '▁self', '.', 'weights', '▁+', '▁self', '.', 'bias']

Will there be any other fixes for this?

@ArthurZucker (Collaborator) commented Jan 3, 2024

If you want to keep the whitespace, normalizers.Replace(Regex(" {2,}"), "▁") should indeed not be used.
Left stripping can be kept, but you would also need to add the byte-fallback tokens to the vocab ('<0x0A>' is a newline via byte fallback) if you want it to have the same behaviour!

Regarding the merges, it might be the frequency of the ▁▁▁▁▁▁▁ token that prevents the model from learning it, but that should not be related to the pre-processing.

So the last issue is probably the bytefallback.
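
One way to at least carry the byte pieces over to the new vocab until the trainer handles byte fallback natively (a sketch; passing them through new_special_tokens is an assumption on my side and not the proper byte-fallback mechanism):

# Sketch: keep the 256 byte-fallback pieces in the retrained vocab.
# new_special_tokens is an existing argument of train_new_from_iterator;
# adding the bytes this way is a workaround, not true byte fallback.
byte_tokens = [f"<0x{i:02X}>" for i in range(256)]
new_tokenizer = old_tokenizer.train_new_from_iterator(
    python_generator(),
    vocab_size=52000,
    new_special_tokens=byte_tokens,
)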

@anderleich

@ArthurZucker are there any plans to add all those fixes to the train_new_from_iterator function for Llama2 models?

@ArthurZucker (Collaborator)

There are plans to add these fixes to the LlamaTokenizer as a whole (specifically the pre-tokenizer vs. normalizer change) in #26678. The byte-fallback part needs to be added to tokenizers; there is a plan, but I don't have the bandwidth just yet! 🤗
