[LlamaSlowConverter] Slow to Fast better support (#29797)
* fix

* fix test

* style

* nit

* rather rely on convert token to id

* fix quality

* Update src/transformers/convert_slow_tokenizer.py
ArthurZucker authored and amyeroberts committed Mar 28, 2024
1 parent 02b1012 commit e40fe39
Showing 2 changed files with 33 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/transformers/convert_slow_tokenizer.py
@@ -1331,9 +1331,9 @@ class LlamaConverter(SpmConverter):
 
     def vocab(self, proto):
         vocab = [
-            ("<unk>", 0.0),
-            ("<s>", 0.0),
-            ("</s>", 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
         ]
         vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
         return vocab
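
Why this matters: the converter previously hard-coded <unk>, <s> and </s> for ids 0-2, which produces a wrong fast tokenizer for Llama-derived checkpoints that map those ids to other strings. Asking the original slow tokenizer via convert_ids_to_tokens keeps the converted vocabulary in sync with whatever the checkpoint defines. A minimal sketch of that lookup, assuming the checkpoint used in the new test below is reachable (the printed values depend on the checkpoint):

    from transformers import AutoTokenizer

    # Load the slow (sentencepiece-backed) tokenizer so ids 0-2 come from the
    # checkpoint's own vocabulary rather than from hard-coded defaults.
    slow = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)

    # This is the lookup the converter now performs for the first three pieces.
    print([slow.convert_ids_to_tokens(i) for i in range(3)])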
@@ -1371,9 +1371,9 @@ def tokenizer(self, proto):
             )
             tokenizer.add_special_tokens(
                 [
-                    AddedToken("<unk>", normalized=False, special=True),
-                    AddedToken("<s>", normalized=False, special=True),
-                    AddedToken("</s>", normalized=False, special=True),
+                    AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
+                    AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
+                    AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
                 ]
             )
         else:
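
The BPE branch gets the same treatment when the special tokens are registered. AddedToken(..., normalized=False, special=True) makes the fast tokenizer match the string verbatim, without passing it through the normalizer, and flags it as special so it can be skipped on decode. A hedged sketch of the same flags at the transformers level, mirroring how the new test registers <image> (the checkpoint name is only illustrative):

    from transformers import AddedToken, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)

    # normalized=False: the string is matched as-is instead of being rewritten or
    # split by the normalizer; special=True: the token is treated as special.
    tok.add_tokens(AddedToken("<image>", normalized=False, special=True), special_tokens=True)

    print(tok.tokenize("<image>\nWhat is shown in this image?"))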
27 changes: 27 additions & 0 deletions tests/models/llava/test_modeling_llava.py
@@ -22,6 +22,7 @@
 
 from transformers import (
     AutoProcessor,
+    AutoTokenizer,
     LlavaConfig,
     LlavaForConditionalGeneration,
     is_torch_available,
@@ -575,3 +576,29 @@ def test_llava_merge_inputs_error_bug(self):
             labels=input_ids,
         ).loss
         loss.backward()
+
+    def test_tokenizer_integration(self):
+        slow_tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)
+        slow_tokenizer.add_tokens("<image>", True)
+
+        fast_tokenizer = AutoTokenizer.from_pretrained(
+            "liuhaotian/llava-v1.6-34b",
+            bos_token="<|startoftext|>",
+            eos_token="<|endoftext|>",
+            from_slow=True,
+            legacy=False,
+        )
+        fast_tokenizer.add_tokens("<image>", True)
+
+        prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
+        # If the token is added as special, it's not normalized, and the only diff is the extra space after special tokens.
+        # https://github.com/huggingface/transformers/pull/28881 is the fix for this.
+        self.assertEqual(
+            slow_tokenizer.tokenize(prompt),
+            ['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n']
+        )  # fmt: skip
+
+        self.assertEqual(
+            fast_tokenizer.tokenize(prompt),
+            ['<|im_start|>', '▁system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', '▁user', '\n', '<image>', '▁', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', '▁assistant', '\n']
+        )  # fmt: skip
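
The test goes through AutoTokenizer with from_slow=True, which routes through convert_slow_tokenizer.py under the hood. The round trip can also be checked directly against the converter's output; a hedged sketch, assuming the convert_slow_tokenizer entry point in that module (it returns the tokenizers backend a fast tokenizer wraps):

    from transformers import AutoTokenizer
    from transformers.convert_slow_tokenizer import convert_slow_tokenizer

    slow = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)

    # Dispatches to LlamaConverter for Llama-style tokenizers and returns a
    # tokenizers.Tokenizer built from the slow tokenizer's sentencepiece model.
    backend = convert_slow_tokenizer(slow)

    # With this commit, the first three ids should mirror the slow tokenizer
    # instead of always being <unk>, <s>, </s>.
    print([backend.id_to_token(i) for i in range(3)])
    print([slow.convert_ids_to_tokens(i) for i in range(3)])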
