1+ import bisect
12import os
23import sentencepiece as spm
34import tiktoken
@@ -21,6 +22,15 @@ def bos_id(self):
def eos_id(self):
    """Return the end-of-sequence token id.

    Abstract hook: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
2324
def id_to_piece(self, token_id):
    """Map a single token id to its surface string.

    Abstract hook: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
27+
def piece_to_id(self, token_str):
    """Map a token's surface string to its id.

    Abstract hook: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
30+
def is_special_token(self, token_id):
    """Report whether token_id denotes a special (non-content) token.

    Abstract hook: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
33+
2434class SentencePieceWrapper (TokenizerInterface ):
2535 def __init__ (self , model_path ):
2636 super ().__init__ (model_path )
@@ -38,6 +48,17 @@ def bos_id(self):
def eos_id(self):
    """Return the end-of-sequence token id reported by the underlying
    SentencePiece processor."""
    eos = self.processor.eos_id()
    return eos
4050
def id_to_piece(self, token_id):
    """Decode one token id to its text, translating SentencePiece's
    "▁" word-boundary marker back into a plain space."""
    piece = self.processor.id_to_piece(token_id)
    return piece.replace("▁", " ")
53+
def piece_to_id(self, token_str):
    """Look up the id for a token string, first converting spaces to the
    "▁" marker that SentencePiece uses internally."""
    normalized = token_str.replace(" ", "▁")
    return self.processor.piece_to_id(normalized)
56+
def is_special_token(self, token_id):
    """Report whether token_id is special: a control, unknown, or
    unused token according to the SentencePiece processor."""
    if self.processor.IsControl(token_id):
        return True
    if self.processor.IsUnknown(token_id):
        return True
    return self.processor.IsUnused(token_id)
61+
4162class TiktokenWrapper (TokenizerInterface ):
4263 """
4364 Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -53,7 +74,7 @@ def __init__(self, model_path):
5374 super ().__init__ (model_path )
5475 assert os .path .isfile (model_path ), str (model_path )
5576 mergeable_ranks = load_tiktoken_bpe (str (model_path ))
56- num_base_tokens = len (mergeable_ranks )
77+ self . num_base_tokens = len (mergeable_ranks )
5778 special_tokens = [
5879 "<|begin_of_text|>" ,
5980 "<|end_of_text|>" ,
@@ -70,7 +91,7 @@ def __init__(self, model_path):
7091 for i in range (5 , self .num_reserved_special_tokens - 5 )
7192 ]
7293 self .special_tokens = {
73- token : num_base_tokens + i for i , token in enumerate (special_tokens )
94+ token : self . num_base_tokens + i for i , token in enumerate (special_tokens )
7495 }
7596 self .model = tiktoken .Encoding (
7697 name = Path (model_path ).name ,
@@ -94,6 +115,15 @@ def bos_id(self):
94115 def eos_id (self ):
95116 return self ._eos_id
96117
def id_to_piece(self, token_id):
    """Decode a single token id back to its text via the tiktoken model."""
    single = [token_id]
    return self.model.decode(single)
120+
def piece_to_id(self, token_str):
    """Return the id that tiktoken assigns to token_str as one whole token."""
    model = self.model
    return model.encode_single_token(token_str)
123+
def is_special_token(self, token_id):
    """Report whether token_id falls in the special-token id range.

    Special tokens occupy the contiguous id block immediately after the
    base BPE vocabulary: [num_base_tokens, num_base_tokens + len(special_tokens)).
    """
    upper = self.num_base_tokens + len(self.special_tokens)
    # Chained comparison is the idiomatic Python range test
    # (replaces the original `x >= a and x < b`).
    return self.num_base_tokens <= token_id < upper
126+
97127def get_tokenizer (tokenizer_model_path , model_name ):
98128 """
99129 Factory function to get the appropriate tokenizer based on the model name.
0 commit comments