Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LlamaSlowConverter] Slow to Fast better support #29797

Merged
merged 8 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,9 +1331,9 @@ class LlamaConverter(SpmConverter):

def vocab(self, proto):
vocab = [
("<unk>", 0.0),
("<s>", 0.0),
("</s>", 0.0),
(self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
(self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
(self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
Comment on lines +1334 to +1336
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming having the mapping of ids to tokens here is fine because we're always handling Llama in this case?

]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab
Expand Down Expand Up @@ -1371,9 +1371,9 @@ def tokenizer(self, proto):
)
tokenizer.add_special_tokens(
[
AddedToken("<unk>", normalized=False, special=True),
AddedToken("<s>", normalized=False, special=True),
AddedToken("</s>", normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
]
)
else:
Expand Down
27 changes: 27 additions & 0 deletions tests/models/llava/test_modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from transformers import (
AutoProcessor,
AutoTokenizer,
LlavaConfig,
LlavaForConditionalGeneration,
is_torch_available,
Expand Down Expand Up @@ -575,3 +576,29 @@ def test_llava_merge_inputs_error_bug(self):
labels=input_ids,
).loss
loss.backward()

def test_tokenizer_integration(self):
slow_tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)
slow_tokenizer.add_tokens("<image>", True)

fast_tokenizer = AutoTokenizer.from_pretrained(
"liuhaotian/llava-v1.6-34b",
bos_token="<|startoftext|>",
eos_token="<|endoftext|>",
from_slow=True,
legacy=False,
)
fast_tokenizer.add_tokens("<image>", True)

prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
# If the token is added as special, it's not normalized, and the only diff is the extra space after special tokens.
# https://github.com/huggingface/transformers/pull/28881 is the fix for this.
self.assertEqual(
slow_tokenizer.tokenize(prompt),
['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n']
) # fmt: skip

self.assertEqual(
fast_tokenizer.tokenize(prompt),
['<|im_start|>', '▁system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', '▁user', '\n', '<image>', '▁', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', '▁assistant', '\n']
) # fmt: skip
Loading