
[LlamaTokenizerFast] nit update post_processor on the fly #23855

Merged (7 commits) on May 30, 2023

44 changes: 44 additions & 0 deletions src/transformers/models/llama/tokenization_llama_fast.py
@@ -16,6 +16,8 @@
from shutil import copyfile
from typing import Optional, Tuple

from tokenizers import processors

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
from ...utils.versions import require_version
@@ -84,6 +86,8 @@ def __init__(
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        add_bos_token=True,
        add_eos_token=False,
        **kwargs,
    ):
        super().__init__(
@@ -95,10 +99,50 @@
            eos_token=eos_token,
            **kwargs,
        )
        self._add_bos_token = add_bos_token
        self._add_eos_token = add_eos_token
        self.update_post_processor()

        self.vocab_file = vocab_file
        self.can_save_slow_tokenizer = False if not self.vocab_file else True

    def update_post_processor(self):
        bos = self.bos_token
        bos_token_id = self.bos_token_id

        eos = self.eos_token
        eos_token_id = self.eos_token_id

        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"

        special_tokens = []
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )

    @property
    def add_eos_token(self):
        return self._add_eos_token

    @property
    def add_bos_token(self):
        return self._add_bos_token

    @add_eos_token.setter
    def add_eos_token(self, value):
        self._add_eos_token = value
        self.update_post_processor()

    @add_bos_token.setter
    def add_bos_token(self, value):
        self._add_bos_token = value
        self.update_post_processor()

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
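
Note (not part of the diff): a minimal sketch of how the template strings built in update_post_processor expand under the defaults used above (add_bos_token=True, add_eos_token=False, bos_token="<s>", eos_token="</s>"); the printed values are only illustrative.

    # Standalone expansion of the f-string templates from update_post_processor.
    bos, eos = "<s>", "</s>"
    add_bos_token, add_eos_token = True, False

    # Multiplying a string by a bool keeps it (True) or drops it (False).
    single = f"{(bos + ':0 ') * add_bos_token}$A:0{(' ' + eos + ':0') * add_eos_token}"
    pair = f"{single}{(' ' + bos + ':1') * add_bos_token} $B:1{(' ' + eos + ':1') * add_eos_token}"

    print(single)  # <s>:0 $A:0
    print(pair)    # <s>:0 $A:0 <s>:1 $B:1

In tokenizers' TemplateProcessing, $A and $B stand for the first and second sequence and the :0/:1 suffixes set the token type ids, which is why the post-processor has to be rebuilt whenever add_bos_token or add_eos_token changes.
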
33 changes: 33 additions & 0 deletions tests/models/llama/test_tokenization_llama.py
@@ -315,6 +315,39 @@ def integration_tests(self):
            },
        )

    def test_fast_special_tokens(self):
        slow_tokenizer = self.tokenizer
        fast_tokenizer = self.rust_tokenizer
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [1, 319, 4559, 1243]

        fast_tokenizer.add_eos_token = False
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243]

        fast_tokenizer.add_eos_token = True
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243, 2]

        slow_tokenizer.add_eos_token = True
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [1, 319, 4559, 1243, 2]

        fast_tokenizer = LlamaTokenizerFast.from_pretrained(
            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
        )
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [319, 4559, 1243, 2]

        slow_tokenizer = LlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
        )
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [319, 4559, 1243, 2]

        self.tokenizer.add_eos_token = False
        self.rust_tokenizer.add_eos_token = False

    @slow
    def test_conversion(self):
        # This is excruciatingly slow since it has to recreate the entire merge
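
For context, a short usage sketch (not from the PR) of what the new setters enable, reusing the hf-internal-testing/llama-tokenizer checkpoint and the expected ids from the test above:

    from transformers import LlamaTokenizerFast

    tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

    print(tok.encode("A sample test"))  # [1, 319, 4559, 1243] -> BOS only (defaults)
    tok.add_eos_token = True            # setter rebuilds the post-processor on the fly
    print(tok.encode("A sample test"))  # [1, 319, 4559, 1243, 2] -> BOS and EOS appended

Because add_bos_token and add_eos_token are now properties whose setters call update_post_processor, flipping them on an existing tokenizer takes effect immediately instead of requiring a fresh from_pretrained call.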