# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import Optional, Tuple

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
from ...utils.versions import require_version


# The fast tokenizer relies on tokenizer features introduced in 0.13.3.
require_version("tokenizers>=0.13.3")

# The slow (sentencepiece-based) tokenizer is optional; when the
# sentencepiece package is absent, conversion to/from the slow class
# is simply unavailable.
if is_sentencepiece_available():
    from .tokenization_llama import LlamaTokenizer
else:
    LlamaTokenizer = None

logger = logging.get_logger(__name__)

# Logical vocab-file keys mapped to their on-disk filenames.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
2134
2235class LlamaTokenizerFast (PreTrainedTokenizerFast ):
2336 """
@@ -59,6 +72,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
5972 token instead.
6073 """
6174
75+ vocab_files_names = VOCAB_FILES_NAMES
76+ slow_tokenizer_class = LlamaTokenizer
6277 padding_side = "left"
6378
6479 def __init__ (
@@ -80,3 +95,25 @@ def __init__(
8095 eos_token = eos_token ,
8196 ** kwargs ,
8297 )
98+
99+ self .vocab_file = vocab_file
100+ self .can_save_slow_tokenizer = False if not self .vocab_file else True
101+
102+ def save_vocabulary (self , save_directory : str , filename_prefix : Optional [str ] = None ) -> Tuple [str ]:
103+ if not self .can_save_slow_tokenizer :
104+ raise ValueError (
105+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
106+ "tokenizer."
107+ )
108+
109+ if not os .path .isdir (save_directory ):
110+ logger .error (f"Vocabulary path ({ save_directory } ) should be a directory" )
111+ return
112+ out_vocab_file = os .path .join (
113+ save_directory , (filename_prefix + "-" if filename_prefix else "" ) + VOCAB_FILES_NAMES ["vocab_file" ]
114+ )
115+
116+ if os .path .abspath (self .vocab_file ) != os .path .abspath (out_vocab_file ):
117+ copyfile (self .vocab_file , out_vocab_file )
118+
119+ return (out_vocab_file ,)
0 commit comments