
Adding ByteFallback support for tokenizers. #1183

Merged
merged 11 commits into main on Mar 23, 2023

Conversation

@Narsil (Collaborator) commented Mar 17, 2023

Two items added:

  • A byte_fallback flag for the BPE model. This is in charge of emitting
    <0x61>-style byte tokens instead of unk for unknown tokens.
  • A ByteFallback decoder, which is in charge of putting everything back
    into a string whenever possible, showing � when the byte decoding fails
    (behavior checked against LlamaTokenizer in transformers).

Fixes #929

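For illustration, a minimal sketch of how the two pieces fit together (toy vocab and names are hypothetical, not taken from the PR):

```python
from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE

# Toy vocab containing the 256 <0xNN> byte tokens plus a couple of regular tokens.
vocab = {"<unk>": 0, "a": 1}
vocab.update({f"<0x{i:02X}>": 2 + i for i in range(256)})

tokenizer = Tokenizer(BPE(vocab=vocab, merges=[], unk_token="<unk>", byte_fallback=True))
tokenizer.decoder = decoders.ByteFallback()

ids = tokenizer.encode("aé").ids
print(tokenizer.encode("aé").tokens)  # ['a', '<0xC3>', '<0xA9>'] -- "é" falls back to its UTF-8 bytes
print(tokenizer.decode(ids))          # "aé" -- the decoder reassembles the bytes
print(tokenizer.decode(ids[:2]))      # "a�" -- a lone 0xC3 is not valid UTF-8, so � is shown
```
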
@sgugger (Contributor) left a comment

Thanks for doing this so quickly!

@HuggingFaceDocBuilderDev commented Mar 17, 2023

The documentation is not available anymore as the PR was closed or merged.

Review comment on bindings/python/test.py (outdated, resolved)
Narsil added a commit to Narsil/transformers that referenced this pull request Mar 20, 2023
- Requires the huggingface/tokenizers#1183 version.
- Only supports byte_fallback for llama, raises otherwise (safety net).
- Lots of open questions around special tokens.

How to test:

```python

from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

# Flip to True to reload the previously saved tok.json instead of converting again.
if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",
    "生活的真谛是[MASK]。",
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
```
@theblackcat102

Can confirm this works with byte_fallback set to true for the llama sentencepiece model.

>>> from tokenizers import Tokenizer
>>> tokenizer = Tokenizer.from_pretrained("theblackcat102/llama-7b-test")

without byte fallback

>>> tokenizer.encode("你好啊測試測試").ids
[29871, 30919, 31076, 0]

with byte fallback

>>> tokenizer.encode("你好啊測試測試").ids
[29871, 30919, 31076, 232, 152, 141, 233, 187, 175, 235, 172, 169, 233, 187, 175, 235, 172, 169]

@Narsil (Collaborator, Author) commented Mar 22, 2023

The byte fallback is working as advertised.

However, I still find some differences in tokenization between the llama SPM model and tokenizers after converting one to the other.

It's usually a simple matter of proper configuration, but this particular model is showing odd issues, which require much lower-level investigation than usual. Hang tight.

@Narsil Narsil merged commit 73637a0 into main Mar 23, 2023
@Narsil Narsil deleted the byte_fallback branch March 23, 2023 15:04
Narsil added a commit to Narsil/transformers that referenced this pull request Apr 4, 2023
(Same commit message and test script as the Mar 20, 2023 entry above.)
Narsil added a commit to Narsil/transformers that referenced this pull request Apr 5, 2023
(Same commit message and test script as the Mar 20, 2023 entry above, plus the following squashed WIP commit subjects.)

The converter + some test script.

The test script.

Tmp save.

Adding Fast tokenizer + tests.

Adding the tokenization tests.

Correct combination.

Small fix.

Fixing tests.

Fixing with latest update.

Rebased.

fix copies + normalized added tokens  + copies.

Adding doc.

TMP.

Doc + split files.

Doc.

Versions + try import.

Fix Camembert + warnings -> Error.

Fix by ArthurZucker.

Not a decorator.
Narsil added a commit to huggingface/transformers that referenced this pull request Apr 6, 2023
* Adding Llama FastTokenizer support.

(Same commit message, test script, and WIP commit subjects as the entries above.)

* Fixing comments.

* Adding more to docstring.

* Doc rewriting.
xloem pushed a commit to xloem/transformers that referenced this pull request Apr 9, 2023
(Same commit message as the huggingface/transformers commit above.)
@kjtaed commented Apr 11, 2023

The "byte_fallback" option does not decompose unknown UTF-8 characters into bytes.

Is there an example of training code using byte_fallback?

from tokenizers import Tokenizer, decoders, pre_tokenizers, trainers
from tokenizers.models import BPE

# vocab_size, min_frequency, special_tokens, and files are defined elsewhere
tokenizer = Tokenizer(BPE(byte_fallback=True))
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
    pre_tokenizers.Digits(individual_digits=True),
    pre_tokenizers.Metaspace(),
])
tokenizer.decoder = decoders.Sequence([
    decoders.Metaspace(),
    decoders.ByteFallback(),
])
trainer = trainers.BpeTrainer(
    vocab_size=vocab_size,
    min_frequency=min_frequency,
    special_tokens=special_tokens,
)
tokenizer.train(files, trainer=trainer)

@Narsil (Collaborator, Author) commented Apr 11, 2023

The "byte_fallback" option does not decompose unknown UTF-8 characters into bytes.

No, it transforms unknown tokens (unk) into their <0x0a>-style byte tokens (if they exist in the vocab), one for each byte in the unknown token.
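
A minimal illustration of that condition (toy vocab, names hypothetical): if the <0xNN> byte tokens are not in the vocab, byte_fallback has nothing to emit, so the unknown character stays unk.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Toy vocab WITHOUT the <0xNN> byte tokens: byte_fallback cannot decompose
# the unknown character, so it falls back to the regular unk token.
vocab = {"<unk>": 0, "a": 1}
tokenizer = Tokenizer(BPE(vocab=vocab, merges=[], unk_token="<unk>", byte_fallback=True))
print(tokenizer.encode("aé").tokens)  # ['a', '<unk>']
```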

@kjtaed commented Apr 11, 2023

How do I add byte tokens to the vocab?

@Narsil (Collaborator, Author) commented Apr 11, 2023

https://huggingface.co/docs/tokenizers/api/trainers

initial_alphabet=[f'<0x{x:02x}>' for x in range(256)], for instance? (You may need to add more depending on the rest of your setup.)

@kjtaed commented Apr 11, 2023

I tried it earlier, but the byte tokens are not added.

"If the strings contain more than one character, only the first one is kept."
(https://huggingface.co/docs/tokenizers/api/trainers#tokenizers.trainers.UnigramTrainer.initial_alphabet)

I was able to add byte tokens by modifying the model file (JSON file).
However, it was not possible to add byte tokens using the tokenizers API.

@Narsil (Collaborator, Author) commented Apr 12, 2023

You are linking the UnigramTrainer here, not the BPE trainer.

Which one are you using? Unigram byte fallback is not yet supported.

@kjtaed commented Apr 13, 2023

Sorry.

initial_alphabet (List[str]) — A list of characters to include in the initial alphabet, even if not seen in the training dataset. If the strings contain more than one character, only the first one is kept.
(https://huggingface.co/docs/tokenizers/api/trainers#tokenizers.trainers.BpeTrainer.initial_alphabet)

Byte tokens (['<0x00>', '<0x01>', ...]) cannot be added with 'initial_alphabet'.

If you have example code for 'training BPE with byte_fallback', can you share it?

@DouglasOrr

If it helps, setting special_tokens=[f"<0x{i:02X}>" for i in range(256)] seems to work.
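
Building on that suggestion, a minimal end-to-end training sketch (corpus file name, vocab size, and special tokens are placeholders; the pre-tokenizer/decoder setup mirrors the snippet earlier in the thread). Passing the 256 <0xNN> tokens via special_tokens guarantees they end up in the vocab, so byte_fallback has something to emit (note the caveat about special tokens in the next comment):

```python
from tokenizers import Tokenizer, decoders, pre_tokenizers, trainers
from tokenizers.models import BPE

# The 256 byte tokens, formatted the same way byte_fallback emits them.
byte_tokens = [f"<0x{i:02X}>" for i in range(256)]

tokenizer = Tokenizer(BPE(unk_token="<unk>", byte_fallback=True))
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
    pre_tokenizers.Digits(individual_digits=True),
    pre_tokenizers.Metaspace(),
])
tokenizer.decoder = decoders.Sequence([
    decoders.Metaspace(),
    decoders.ByteFallback(),
])

trainer = trainers.BpeTrainer(
    vocab_size=32000,                                    # placeholder
    special_tokens=["<unk>", "<s>", "</s>"] + byte_tokens,
)
tokenizer.train(["corpus.txt"], trainer=trainer)         # hypothetical corpus file
tokenizer.save("bpe_byte_fallback.json")
```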

@chris-ha458 (Contributor)

Setting them as special tokens does not allow them to be trained in certain situations.
This can be hand-changed later, but there should be a way to do one of the following:

  1. change many flags for special tokens at once (e.g. flipping all 256 of them to special=false)
  2. add tokens that won't be 'lost' (if we use the add-token method before training, they may or may not be lost)
  3. change the byte_fallback behavior to handle b'\x80'-style notation, since each of those is technically "one character"
  4. when byte_fallback is set, internally add the <0x00> ~ <0xFF> tokens to the initial alphabet automatically (what sentencepiece does)

I can think of more, but these seem to be the most effective methods.

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
(Same commit message as the huggingface/transformers commit above.)
Successfully merging this pull request may close: Implement the Byte->char hack of SPM within BPE (#929)