Inconsistency between CodeLlamaTokenizer and CodeLlamaTokenizerFast #25881

Closed
4 tasks
rfriel opened this issue Aug 30, 2023 · 11 comments · Fixed by #26678

Comments

@rfriel

rfriel commented Aug 30, 2023

System Info

  • transformers version: 4.33.0.dev0
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.3
  • Accelerate version: 0.22.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.2 (gpu)
  • Jax version: 0.4.14
  • JaxLib version: 0.4.14
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer

model = "codellama/CodeLlama-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
tokenizer_fast = AutoTokenizer.from_pretrained(model, use_fast=True)

print(tokenizer.encode("<s>\n", add_special_tokens=False))
# [1, 13]

print(tokenizer_fast.encode("<s>\n", add_special_tokens=False))
# [1, 29871, 13]

# the same issue occurs with any element of `tokenizer.all_special_tokens`, not just <s>

for special_token in tokenizer.all_special_tokens:
    print(special_token)
    print(tokenizer.encode(f"{special_token}\n", add_special_tokens=False))
    print(tokenizer_fast.encode(f"{special_token}\n", add_special_tokens=False))
    print()
    
# <s>
# [1, 13]
# [1, 29871, 13]

# </s>
# [2, 13]
# [2, 29871, 13]

# <unk>
# [0, 13]
# [0, 29871, 13]

# ▁<PRE>
# [32007, 13]
# [32007, 29871, 13]

# ▁<MID>
# [32009, 13]
# [32009, 29871, 13]

# ▁<SUF>
# [32008, 13]
# [32008, 29871, 13]

# ▁<EOT>
# [32010, 13]
# [32010, 29871, 13]
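
The check above can be wrapped in a small helper that collects the disagreements instead of printing them. This is only a sketch: `find_mismatches` and the `Encoder` alias are hypothetical names, not part of transformers, and the two dictionaries are stand-ins that replay the IDs reported above so the snippet runs without downloading the model.

```python
from typing import Callable, Iterable, List, Tuple

# Hypothetical alias: any callable that maps text to a list of token IDs.
Encoder = Callable[[str], List[int]]


def find_mismatches(
    encode_a: Encoder, encode_b: Encoder, texts: Iterable[str]
) -> List[Tuple[str, List[int], List[int]]]:
    """Return (text, ids_a, ids_b) for every text the two encoders disagree on."""
    mismatches = []
    for text in texts:
        ids_a, ids_b = encode_a(text), encode_b(text)
        if ids_a != ids_b:
            mismatches.append((text, ids_a, ids_b))
    return mismatches


# Stand-in encoders replaying two of the outputs reported in this issue.
SLOW = {"<s>\n": [1, 13], "</s>\n": [2, 13]}
FAST = {"<s>\n": [1, 29871, 13], "</s>\n": [2, 29871, 13]}

mismatches = find_mismatches(SLOW.__getitem__, FAST.__getitem__, SLOW)
for text, slow_ids, fast_ids in mismatches:
    print(repr(text), slow_ids, fast_ids)
```

With the real tokenizers, the stand-ins would be replaced by something like `lambda t: tokenizer.encode(t, add_special_tokens=False)` and the same for `tokenizer_fast`.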

Expected behavior

The two tokenizers should have the same behavior.

There's no exact equivalent of add_special_tokens=False in the original facebookresearch/codellama repo, but the following seems roughly equivalent for the "<PRE>" case:

# assuming repo is cloned at ./codellama and 7b is downloaded

import sys
sys.path.append('codellama')

from llama.tokenizer import Tokenizer

tokenizer_facebookresearch = Tokenizer('codellama/CodeLlama-7b/tokenizer.model')

print(tokenizer_facebookresearch.encode('<PRE>\n', bos=False, eos=False))
# [32007, 13]

which agrees with CodeLlamaTokenizer and disagrees with CodeLlamaTokenizerFast [1].

Footnotes

  1. I realize that one isn't supposed to directly encode "<PRE>" with the HF tokenizer; I'm just using it to construct a case where the HF and Facebook tokenizers can be compared. The Facebook tokenizer won't .encode the EOS or BOS tokens to their corresponding IDs: it treats them as an ordinary string of 3 characters. But it does encode the FIM tokens to their IDs, as used above with "<PRE>".
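
To make the comparison against the reference output concrete, the following sketch reports which IDs a candidate encoding contains beyond the reference. `inserted_ids` is a hypothetical helper built on the standard-library `difflib`; the ID lists are copied from the "<PRE>" case above rather than produced by a live tokenizer.

```python
from difflib import SequenceMatcher
from typing import List, Tuple


def inserted_ids(
    reference: List[int], candidate: List[int]
) -> List[Tuple[int, List[int]]]:
    """Return (position_in_candidate, ids) for each run of IDs found only in candidate."""
    matcher = SequenceMatcher(a=reference, b=candidate, autojunk=False)
    extras = []
    for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
        # "insert" and "replace" both mean the candidate has IDs the reference lacks.
        if tag in ("insert", "replace"):
            extras.append((j1, candidate[j1:j2]))
    return extras


# IDs reported in this issue for "<PRE>\n" (assumed here, not recomputed).
reference_ids = [32007, 13]    # facebookresearch tokenizer
slow_ids = [32007, 13]         # CodeLlamaTokenizer
fast_ids = [32007, 29871, 13]  # CodeLlamaTokenizerFast

print(inserted_ids(reference_ids, slow_ids))  # []
print(inserted_ids(reference_ids, fast_ids))  # [(1, [29871])]
```

In every case listed in the reproduction, the only difference is the single ID 29871 inserted directly after the special token.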

@ArthurZucker
Collaborator

Yep, this is a known bug and the correct output is the slow output 😉
The related fix was introduced in #25224 and has not yet been propagated to the fast tokenizers. It's on my TODO list!
You can deactivate the fix by setting legacy = True.

@ArthurZucker
Collaborator

The fix is in #26678!

@zpx01

zpx01 commented Oct 18, 2023

@ArthurZucker When can we expect the fix to go through?

@ArthurZucker
Collaborator

Maybe a week or so, this needs a release in tokenizers and a merge in transformers!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

The PR in tokenizers is ready; I'll try to do a release today or tomorrow. The fix will also need a release in transformers, but that should follow quickly 😉


github-actions bot commented Dec 8, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker
Collaborator

Actually, I'm not sure I'll ship it fast enough; it needs additional testing.

@xenova
Contributor

xenova commented Dec 16, 2023

Quoting @ArthurZucker: "The fix is in #26678!"

@ArthurZucker just to confirm: in that PR, although you override the base SpmConverter class, the LlamaConverter itself overrides the normalizer and the pre_tokenizer, so the changes made there won't fix this problem.

@ArthurZucker
Collaborator

Yes, a separate PR will deal with the Llama converter!

@huggingface huggingface deleted a comment from github-actions bot Jan 10, 2024
@ArthurZucker
Collaborator

There were delays again but this is not stale!
