from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
- from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
+ from axolotl.utils.distributed import (
+     barrier,
+     get_device_count,
+     get_device_type,
+     is_local_main_process,
+     zero_only,
+ )
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -165,7 +171,95 @@ def load_model_config(cfg):
    return model_config


+ def modify_tokenizer_files(
+     tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
+ ) -> str:
+     """
+     Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
+
+     This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
+
+     Args:
+         tokenizer_path: Path or name of the original tokenizer
+         token_mappings: Dict mapping {token_id (int): new_token_string}
+         output_dir: Directory to save the modified tokenizer
+
+     Returns:
+         Path to the modified tokenizer directory
+
+     Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+     """
+
+     import json
+
+     # Create the tokenizer directory in output_dir if it doesn't exist
+     tokenizer_dir = os.path.join(output_dir, "tokenizer")
+     os.makedirs(tokenizer_dir, exist_ok=True)
+
+     if is_local_main_process():  # pylint: disable=too-many-nested-blocks
+         # Load the tokenizer
+         temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+
+         # Save the tokenizer to the output directory
+         temp_tokenizer.save_pretrained(tokenizer_dir)
+
+         # Get the token IDs and map them to their new values
+         token_id_mappings = {
+             int(token_id): new_value for token_id, new_value in token_mappings.items()
+         }
+
+         # 1. Update tokenizer_config.json - added_tokens_decoder
+         config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
+         if os.path.exists(config_path):
+             with open(config_path, "r", encoding="utf-8") as f:
+                 config_data = json.load(f)
+
+             # Update added_tokens_decoder
+             if "added_tokens_decoder" in config_data:
+                 for token_id, new_value in token_id_mappings.items():
+                     token_id_str = str(token_id)
+                     if token_id_str in config_data["added_tokens_decoder"]:
+                         config_data["added_tokens_decoder"][token_id_str][
+                             "content"
+                         ] = new_value
+                     else:
+                         raise ValueError(
+                             f"Token ID {token_id_str} not found in added_tokens_decoder"
+                         )
+
+             # Write the updated config back
+             with open(config_path, "w", encoding="utf-8") as f:
+                 json.dump(config_data, f, indent=2)
+
+         # 2. Update tokenizer.json - added_tokens
+         tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+         if os.path.exists(tokenizer_path):
+             with open(tokenizer_path, "r", encoding="utf-8") as f:
+                 tokenizer_data = json.load(f)
+
+             # Update added_tokens
+             if "added_tokens" in tokenizer_data:
+                 for token_id, new_value in token_id_mappings.items():
+                     for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
+                         if token_entry["id"] == token_id:
+                             tokenizer_data["added_tokens"][i]["content"] = new_value
+                             break
+                     else:
+                         # Reaching this branch means the token_id was not found in tokenizer.json added_tokens
+                         raise ValueError(
+                             f"Token ID {token_id} not found in added_tokens"
+                         )
+
+             # Write the updated tokenizer data back
+             with open(tokenizer_path, "w", encoding="utf-8") as f:
+                 json.dump(tokenizer_data, f, indent=2)
+
+     barrier()
+     return tokenizer_dir
+
+
def load_tokenizer(cfg):
+     """Load and configure the tokenizer based on the provided config."""
    model_config = load_model_config(cfg)
    tokenizer_kwargs = {}
    use_fast = True  # this is the default
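For context, here is a minimal sketch (not part of the diff) of how the new modify_tokenizer_files helper could be exercised on its own. The tokenizer name, token IDs, and replacement strings below are illustrative assumptions; the IDs must correspond to reserved added tokens of the chosen tokenizer, otherwise the helper raises ValueError.

# Hypothetical usage of modify_tokenizer_files (example values, not from this PR)
from transformers import AutoTokenizer

token_mappings = {
    128011: "<|tool_call|>",    # assumed reserved added-token IDs
    128012: "<|tool_result|>",  # replacement strings are examples only
}

modified_dir = modify_tokenizer_files(
    "NousResearch/Meta-Llama-3-8B",  # assumed tokenizer with reserved added tokens
    token_mappings,
    output_dir="./outputs/example-run",
)

# The helper returns <output_dir>/tokenizer, containing the rewritten
# tokenizer_config.json and tokenizer.json; loading from it should surface
# the new token strings at the same IDs.
tokenizer = AutoTokenizer.from_pretrained(modified_dir)
print(tokenizer.convert_ids_to_tokens([128011, 128012]))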
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
    if cfg.tokenizer_type:
        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

+     # Set base tokenizer path
+     tokenizer_path = cfg.tokenizer_config
+
+     # Apply token string overrides if specified
+     if cfg.added_tokens_overrides:
+         # Modify tokenizer files and get path to modified tokenizer
+         tokenizer_path = modify_tokenizer_files(
+             tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+         )
+
    tokenizer = tokenizer_cls.from_pretrained(
-         cfg.tokenizer_config,
+         tokenizer_path,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
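Tying the two hunks together: the new cfg.added_tokens_overrides option is what routes a user config into modify_tokenizer_files before the tokenizer is loaded. Below is a hedged sketch of a config fragment that would trigger this path, written with the DictDefault wrapper already imported above; the keys mirror the cfg attributes used in the diff, while all values are illustrative assumptions.

# Sketch of a config fragment exercising the new load_tokenizer code path
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "tokenizer_config": "NousResearch/Meta-Llama-3-8B",   # assumed base tokenizer
        "output_dir": "./outputs/example-run",
        "added_tokens_overrides": {128011: "<|tool_call|>"},  # token_id -> new string
        "trust_remote_code": False,
    }
)

# load_tokenizer(cfg) would then:
#   1. call modify_tokenizer_files(cfg.tokenizer_config, cfg.added_tokens_overrides,
#      output_dir=cfg.output_dir) on the local main process, and
#   2. load the tokenizer from <output_dir>/tokenizer instead of cfg.tokenizer_config.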