@@ -2,8 +2,11 @@ use fancy_regex::Regex;
 use mlua::prelude::*;
 use rustc_hash::FxHashMap as HashMap;
 use std::collections::HashSet;
+use std::fs::File;
+use std::io::{BufRead, BufReader};
 use std::sync::{Arc, Mutex};
 use std::thread;
+use base64;
 
 #[cfg(feature = "multithreading")]
 const MAX_NUM_THREADS: usize = 128;
@@ -191,12 +194,15 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
     let _new = lua.create_function(
         move |_,
-              (encoder, special_tokens_encoder, pattern): (
-            HashMap<LuaString, usize>,
+              (encoder_path, special_tokens_encoder, pattern): (
+            String,
             HashMap<String, usize>,
             String,
         )| {
-            new(&*state, encoder, special_tokens_encoder, pattern);
+            new(&*state, encoder_path, special_tokens_encoder, pattern);
             Ok(())
         },
     )?;
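+    // Illustrative Lua-side call (a sketch; how `_new` is exposed to Lua is
+    // not shown in this diff):
+    //   _new("cl100k_base.tiktoken", special_tokens, pattern)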
@@ -210,14 +216,27 @@ pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
 
 fn new(
     state: &State,
-    iencoder: HashMap<LuaString, usize>,
+    encoder_path: String,
     special_tokens_encoder: HashMap<String, usize>,
     pattern: String,
 ) {
-    let encoder: HashMap<Vec<u8>, usize> = iencoder
-        .into_iter()
-        .map(|(k, v)| (k.as_bytes().to_vec(), v))
-        .collect();
+    let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
+    // Read the encoder file: each line is a base64-encoded token and its rank, separated by a space.
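+    // For illustration, a tiktoken-style vocabulary file (e.g. the published
+    // cl100k_base.tiktoken) begins:
+    //   IQ== 0
+    //   Ig== 1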
+    let file = File::open(encoder_path).unwrap();
+    let reader = BufReader::new(file);
+    for line in reader.lines() {
+        let line = line.unwrap();
+        let mut parts = line.split_whitespace();
+        let token = base64::decode(parts.next().unwrap().as_bytes()).unwrap();
+        let rank = parts.next().unwrap().parse().unwrap();
+        encoder.insert(token, rank);
+    }
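+    // Note: with the decoder no longer built (see below), duplicate ranks in
+    // the file are no longer caught by the old encoder/decoder length check.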
     let regex = Regex::new(&pattern)
         .map_err(|e| mlua::Error::external(e))
         .unwrap();
@@ -230,11 +249,6 @@ fn new(
         .map_err(|e| mlua::Error::external(e))
         .unwrap()
     };
-    let decoder: HashMap<usize, Vec<u8>> = encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
-    assert!(
-        encoder.len() == decoder.len(),
-        "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
-    );
     let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
         .iter()
         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
@@ -245,7 +259,8 @@ fn new(
     *core_bpe_lock = Some(CoreBPENative {
         encoder,
         special_tokens_encoder,
-        decoder,
+        // The decoder is left empty; it is no longer built from the encoder.
+        decoder: HashMap::default(),
         special_tokens_decoder,
         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
         special_regex_tls: (0..MAX_NUM_THREADS)