Commit 575e5f2

Authored by mhenrichsen, NanoCode012, and winglian
Update Tokenizer Overrides Handling in models.py (axolotl-ai-cloud#1549)
* override special tokens mock code
* fix(doc): remove duplicate config
* feat: replace added_tokens in tokenizer and add test
* make sure to run tokenizer modification on rank 0 only
* use is local main process instead
* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
1 parent 0134093 commit 575e5f2

File tree

5 files changed: +156 -9 lines changed


docs/config.qmd

+7 -2

@@ -154,8 +154,6 @@ datasets:
       content: value
     # ...
 
-    message_property_mappings:
-
     # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
     roles:
       user: ["human", "user"]
@@ -556,6 +554,13 @@ special_tokens:
 # Add extra tokens.
 tokens:
 
+# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
+# Only works for tokens that are not part of the base vocab (aka are added_tokens).
+# Can be checked if they exist in tokenizer.json added_tokens.
+added_tokens_overrides: # Dict[int, str]
+  # 128041: "<|im_start|>"
+  # 128042: "<|im_end|>"
+
 # FSDP
 fsdp:
 fsdp_config:
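
Usage note (not part of the diff): the new option can be exercised the same way the tests at the bottom of this commit do, by passing the mapping through the config. A minimal sketch, assuming a tokenizer whose reserved added_tokens include the listed IDs; the tokenizer name, IDs, and output directory are illustrative only.

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer

# Sketch: remap two reserved added_tokens to chat-template markers.
# The IDs must already exist in the chosen tokenizer's tokenizer.json added_tokens.
cfg = DictDefault(
    {
        "tokenizer_config": "NousResearch/Llama-3.2-1B",
        "added_tokens_overrides": {  # Dict[int, str]
            128041: "<|im_start|>",
            128042: "<|im_end|>",
        },
        "output_dir": "./outputs",  # modified tokenizer files land in <output_dir>/tokenizer
    }
)

tokenizer = load_tokenizer(cfg)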

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

+1 -0

@@ -855,6 +855,7 @@ class AxolotlInputConfig(
 
     special_tokens: Optional[SpecialTokensConfig] = None
     tokens: Optional[List[str]] = None
+    added_tokens_overrides: Optional[Dict[int, str]] = None
 
     torch_compile: Optional[Union[Literal["auto"], bool]] = None
     torch_compile_backend: Optional[str] = None
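
For illustration, a minimal standalone sketch of how this typed field behaves, assuming the pydantic BaseModel machinery these input models build on (this is a stand-in class, not the real AxolotlInputConfig): numeric string keys coming from a parsed config are coerced to int by the Dict[int, str] annotation.

from typing import Dict, Optional

from pydantic import BaseModel


class TokenOverridesSketch(BaseModel):
    # Mirrors the annotation of the new field on AxolotlInputConfig.
    added_tokens_overrides: Optional[Dict[int, str]] = None


sketch = TokenOverridesSketch(added_tokens_overrides={"128041": "<|im_start|>"})
print(sketch.added_tokens_overrides)  # {128041: '<|im_start|>'}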

src/axolotl/utils/distributed.py

+1 -1

@@ -79,7 +79,7 @@ def is_main_process():
 
 
 def is_local_main_process():
-    return PartialState().is_main_process
+    return PartialState().is_local_main_process
 
 
 def get_world_size():
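
For context on the one-line fix, a sketch assuming accelerate's PartialState (which this module wraps): is_main_process is true only on the single global rank-0 process, while is_local_main_process is true on rank 0 of every node, which is the right gate for per-node filesystem work such as rewriting tokenizer files.

from accelerate import PartialState

state = PartialState()

# True for exactly one process in the whole job (global rank 0).
print(state.is_main_process)

# True for one process per node (local rank 0); multi-node jobs that write
# to node-local storage should key off this instead.
print(state.is_local_main_process)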

src/axolotl/utils/models.py

+106 -2

@@ -57,7 +57,13 @@
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
+from axolotl.utils.distributed import (
+    barrier,
+    get_device_count,
+    get_device_type,
+    is_local_main_process,
+    zero_only,
+)
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -165,7 +171,95 @@ def load_model_config(cfg):
     return model_config
 
 
+def modify_tokenizer_files(
+    tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
+) -> str:
+    """
+    Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
+
+    This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
+
+    Args:
+        tokenizer_path: Path or name of the original tokenizer
+        token_mappings: Dict mapping {token_id (int): new_token_string}
+        output_dir: Directory to save the modified tokenizer
+
+    Returns:
+        Path to the modified tokenizer directory
+
+    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
+    """
+
+    import json
+
+    # Create the tokenizer directory in output_dir if it doesn't exist
+    tokenizer_dir = os.path.join(output_dir, "tokenizer")
+    os.makedirs(tokenizer_dir, exist_ok=True)
+
+    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
+        # Load the tokenizer
+        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+
+        # Save the tokenizer to the output directory
+        temp_tokenizer.save_pretrained(tokenizer_dir)
+
+        # Get the token IDs and map them to their new values
+        token_id_mappings = {
+            int(token_id): new_value for token_id, new_value in token_mappings.items()
+        }
+
+        # 1. Update tokenizer_config.json - added_tokens_decoder
+        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                config_data = json.load(f)
+
+            # Update added_tokens_decoder
+            if "added_tokens_decoder" in config_data:
+                for token_id, new_value in token_id_mappings.items():
+                    token_id_str = str(token_id)
+                    if token_id_str in config_data["added_tokens_decoder"]:
+                        config_data["added_tokens_decoder"][token_id_str][
+                            "content"
+                        ] = new_value
+                    else:
+                        raise ValueError(
+                            f"Token ID {token_id_str} not found in added_tokens_decoder"
+                        )
+
+            # Write the updated config back
+            with open(config_path, "w", encoding="utf-8") as f:
+                json.dump(config_data, f, indent=2)
+
+        # 2. Update tokenizer.json - added_tokens
+        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        if os.path.exists(tokenizer_path):
+            with open(tokenizer_path, "r", encoding="utf-8") as f:
+                tokenizer_data = json.load(f)
+
+            # Update added_tokens
+            if "added_tokens" in tokenizer_data:
+                for token_id, new_value in token_id_mappings.items():
+                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
+                        if token_entry["id"] == token_id:
+                            tokenizer_data["added_tokens"][i]["content"] = new_value
+                            break
+                    else:
+                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens
+                        raise ValueError(
+                            f"Token ID {token_id} not found in added_tokens"
+                        )
+
+            # Write the updated tokenizer data back
+            with open(tokenizer_path, "w", encoding="utf-8") as f:
+                json.dump(tokenizer_data, f, indent=2)
+
+    barrier()
+    return tokenizer_dir
+
+
 def load_tokenizer(cfg):
+    """Load and configure the tokenizer based on the provided config."""
     model_config = load_model_config(cfg)
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
@@ -180,8 +274,18 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
+    # Set base tokenizer path
+    tokenizer_path = cfg.tokenizer_config
+
+    # Apply token string overrides if specified
+    if cfg.added_tokens_overrides:
+        # Modify tokenizer files and get path to modified tokenizer
+        tokenizer_path = modify_tokenizer_files(
+            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
+        )
+
     tokenizer = tokenizer_cls.from_pretrained(
-        cfg.tokenizer_config,
+        tokenizer_path,
         trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
         **tokenizer_kwargs,
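
A rough sketch of the on-disk effect of modify_tokenizer_files (the paths and token ID below are illustrative, not taken from this commit): after the local-rank-0 rewrite, the overridden strings should appear both in tokenizer_config.json under added_tokens_decoder and in tokenizer.json under added_tokens, and the tokenizer is subsequently loaded from <output_dir>/tokenizer.

import json
import os

# Hypothetical post-hoc check that an override landed in the re-saved files.
tokenizer_dir = os.path.join("outputs", "tokenizer")  # created by modify_tokenizer_files

with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), encoding="utf-8") as f:
    config_data = json.load(f)
print(config_data["added_tokens_decoder"]["128041"]["content"])  # e.g. "<|im_start|>"

with open(os.path.join(tokenizer_dir, "tokenizer.json"), encoding="utf-8") as f:
    tokenizer_data = json.load(f)
print([t["content"] for t in tokenizer_data["added_tokens"] if t["id"] == 128041])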

tests/test_tokenizers.py

+41 -4

@@ -1,6 +1,7 @@
 """
 Test cases for the tokenizer loading
 """
+
 import unittest
 
 import pytest
@@ -9,7 +10,7 @@
 from axolotl.utils.models import load_tokenizer
 
 
-class TestTokenizers(unittest.TestCase):
+class TestTokenizers:
     """
     test class for the load_tokenizer fn
     """
@@ -75,12 +76,48 @@ def test_add_additional_special_tokens(self):
             }
         )
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
-        self.assertEqual(len(tokenizer), 32001)
+        assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
+        assert len(tokenizer) == 32001
 
         # ensure reloading the tokenizer again from cfg results in same vocab length
         tokenizer = load_tokenizer(cfg)
-        self.assertEqual(len(tokenizer), 32001)
+        assert len(tokenizer) == 32001
+
+    def test_added_tokens_overrides(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use with tokenizer that has reserved_tokens in added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {
+                    128041: "RANDOM_OVERRIDE_1",
+                    128042: "RANDOM_OVERRIDE_2",
+                },
+                "output_dir": temp_dir,
+            }
+        )
+
+        tokenizer = load_tokenizer(cfg)
+        assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
+            128041
+        ]
+        assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
+            128042
+        ]
+
+    def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
+        cfg = DictDefault(
+            {
+                # use with tokenizer that has reserved_tokens in added_tokens
+                "tokenizer_config": "NousResearch/Llama-3.2-1B",
+                "added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
+                "output_dir": temp_dir,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
+        ):
+            load_tokenizer(cfg)
 
 
 if __name__ == "__main__":
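
Note that the two new tests take a temp_dir argument, which is presumably a pytest fixture defined elsewhere in the suite (e.g. in a conftest.py). A minimal stand-in, under that assumption, would look like the following; it only provides a scratch directory used as cfg.output_dir for the rewritten tokenizer files.

import shutil
import tempfile

import pytest


@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
    # Create a throwaway directory, hand it to the test, then clean it up.
    path = tempfile.mkdtemp()
    yield path
    shutil.rmtree(path)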
