Skip to content

Commit 7350035

Browse files
committed
Add saving in the new format (but no loading yet!)
1 parent cbe0ea5 commit 7350035

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

src/transformers/tokenization_utils_base.py

Lines changed: 37 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121
import copy
2222
import json
23+
from pathlib import Path
2324
import os
2425
import re
2526
import warnings
@@ -146,6 +147,7 @@ class EncodingFast:
146147
ADDED_TOKENS_FILE = "added_tokens.json"
147148
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
148149
CHAT_TEMPLATE_FILE = "chat_template.jinja"
150+
CHAT_TEMPLATE_DIR = "chat_templates"
149151

150152
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
151153
FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -2420,6 +2422,9 @@ def save_pretrained(
24202422
chat_template_file = os.path.join(
24212423
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
24222424
)
2425+
chat_template_dir = os.path.join(
2426+
save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR
2427+
)
24232428

24242429
tokenizer_config = copy.deepcopy(self.init_kwargs)
24252430

@@ -2438,22 +2443,44 @@ def save_pretrained(
24382443
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
24392444
tokenizer_config.update(self.extra_special_tokens)
24402445

2441-
saved_raw_chat_template = False
2446+
saved_raw_chat_template_files = []
24422447
if self.chat_template is not None:
2443-
if isinstance(self.chat_template, dict):
2444-
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
2445-
# They will be reconstructed as a single dict during loading.
2446-
# We're trying to discourage chat template dicts, and they are always
2447-
# saved in the config, never as single files.
2448-
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
2449-
elif kwargs.get("save_raw_chat_template", False):
2448+
if kwargs.get("save_raw_chat_template", False) and isinstance(self.chat_template, str):
2449+
# New format for single templates is to save them as chat_template.jinja
24502450
with open(chat_template_file, "w", encoding="utf-8") as f:
24512451
f.write(self.chat_template)
2452-
saved_raw_chat_template = True
24532452
logger.info(f"chat template saved in {chat_template_file}")
2453+
saved_raw_chat_template_files.append(chat_template_file)
2454+
if "chat_template" in tokenizer_config:
2455+
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
2456+
elif kwargs.get("save_raw_chat_template", False) and isinstance(self.chat_template, dict):
2457+
# New format for multiple templates is to save the default as chat_template.jinja
2458+
# and the other templates in the chat_templates/ directory
2459+
for template_name, template in self.chat_template.items():
2460+
if template_name == "default":
2461+
with open(chat_template_file, "w", encoding="utf-8") as f:
2462+
f.write(self.chat_template["default"])
2463+
logger.info(f"chat template saved in {chat_template_file}")
2464+
saved_raw_chat_template_files.append(chat_template_file)
2465+
else:
2466+
Path(chat_template_dir).mkdir(exist_ok=True)
2467+
template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
2468+
with open(template_filepath, "w", encoding="utf-8") as f:
2469+
f.write(template)
2470+
logger.info(f"chat template saved in {template_filepath}")
2471+
saved_raw_chat_template_files.append(template_filepath)
2472+
saved_raw_chat_template = True
24542473
if "chat_template" in tokenizer_config:
24552474
tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too
2475+
elif isinstance(self.chat_template, dict):
2476+
# Legacy format for multiple templates:
2477+
# chat template dicts are saved to the config as lists of dicts with fixed key names.
2478+
# They will be reconstructed as a single dict during loading.
2479+
# We're trying to discourage chat template dicts, and they are always
2480+
# saved in the config, never as single files.
2481+
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
24562482
else:
2483+
# Legacy format for single templates: Just make them a key in tokenizer_config.json
24572484
tokenizer_config["chat_template"] = self.chat_template
24582485

24592486
if len(self.init_inputs) > 0:
@@ -2508,9 +2535,7 @@ def save_pretrained(
25082535
f.write(out_str)
25092536
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
25102537

2511-
file_names = (tokenizer_config_file, special_tokens_map_file)
2512-
if saved_raw_chat_template:
2513-
file_names += (chat_template_file,)
2538+
file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files)
25142539

25152540
save_files = self._save_pretrained(
25162541
save_directory=save_directory,

0 commit comments

Comments
 (0)