Skip to content

Commit 6a02e98

Browse files
authored
LlamaTokenizerFast Fix (.., from_slow=True). (#22630)
1 parent 09a9888 commit 6a02e98

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/transformers/models/llama/tokenization_llama_fast.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,25 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import os
16+
from shutil import copyfile
17+
from typing import Optional, Tuple
18+
1519
from ...tokenization_utils_fast import PreTrainedTokenizerFast
20+
from ...utils import is_sentencepiece_available, logging
1621
from ...utils.versions import require_version
1722

1823

1924
require_version("tokenizers>=0.13.3")
2025

26+
# Default to no slow tokenizer class; overridden just below when sentencepiece is available.
LlamaTokenizer = None
if is_sentencepiece_available():
    from .tokenization_llama import LlamaTokenizer

logger = logging.get_logger(__name__)

# Default on-disk filename for each vocabulary-file keyword.
VOCAB_FILES_NAMES = {
    "vocab_file": "tokenizer.model",
    "tokenizer_file": "tokenizer.json",
}
33+
2134

2235
class LlamaTokenizerFast(PreTrainedTokenizerFast):
2336
"""
@@ -59,6 +72,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
5972
token instead.
6073
"""
6174

75+
vocab_files_names = VOCAB_FILES_NAMES
76+
slow_tokenizer_class = LlamaTokenizer
6277
padding_side = "left"
6378

6479
def __init__(
@@ -80,3 +95,25 @@ def __init__(
8095
eos_token=eos_token,
8196
**kwargs,
8297
)
98+
99+
self.vocab_file = vocab_file
100+
self.can_save_slow_tokenizer = False if not self.vocab_file else True
101+
102+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Copy the vocabulary file referenced by `self.vocab_file` into `save_directory`.

    Args:
        save_directory (`str`):
            Existing directory that receives the vocabulary file.
        filename_prefix (`str`, *optional*):
            When given (and non-empty), the output filename is `filename_prefix + "-"` followed by the default
            vocabulary filename.

    Returns:
        `Tuple[str]`: One-element tuple holding the path of the vocabulary file inside `save_directory`. If
        `save_directory` is not a directory, an error is logged and `None` is returned instead.

    Raises:
        ValueError: If `self.can_save_slow_tokenizer` is falsy.
    """
    if not self.can_save_slow_tokenizer:
        raise ValueError(
            "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
            "tokenizer."
        )

    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    prefix = filename_prefix + "-" if filename_prefix else ""
    target_path = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])

    # Nothing to copy when the source file already is the destination file.
    source_path = self.vocab_file
    already_in_place = os.path.abspath(source_path) == os.path.abspath(target_path)
    if not already_in_place:
        copyfile(source_path, target_path)

    return (target_path,)

0 commit comments

Comments
 (0)