diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b5b02f6a00aa09..c2646300367033 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -44,7 +44,6 @@ TruncationStrategy, ) from .utils import ( - CHAT_TEMPLATE_NAME, PROCESSOR_NAME, PushToHubMixin, TensorType, @@ -527,18 +526,24 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs): # If we save using the predefined names, we can load using `from_pretrained` # plus we save chat_template in its own file output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) - output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME) + output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja") + output_chat_template_file = os.path.join(save_directory, "chat_template.json") processor_dict = self.to_dict() # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` # to avoid serializing chat template in json config file. So let's get it from `self` directly if self.chat_template is not None: - chat_template_json_string = ( - json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" - ) - with open(output_chat_template_file, "w", encoding="utf-8") as writer: - writer.write(chat_template_json_string) - logger.info(f"chat template saved in {output_chat_template_file}") + if kwargs.get("save_raw_chat_template", False): + with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer: + writer.write(self.chat_template) + logger.info(f"chat template saved in {output_raw_chat_template_file}") + else: + chat_template_json_string = ( + json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" + ) + with open(output_chat_template_file, "w", encoding="utf-8") as writer: + writer.write(chat_template_json_string) + logger.info(f"chat template saved in {output_chat_template_file}") # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and # `auto_map` is not specified. @@ -601,21 +606,23 @@ def get_processor_dict( is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) - chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json") if os.path.isfile(pretrained_model_name_or_path): resolved_processor_file = pretrained_model_name_or_path # cant't load chat-template when given a file as pretrained_model_name_or_path resolved_chat_template_file = None + resolved_raw_chat_template_file = None is_local = True elif is_remote_url(pretrained_model_name_or_path): processor_file = pretrained_model_name_or_path resolved_processor_file = download_url(pretrained_model_name_or_path) # can't load chat-template when given a file url as pretrained_model_name_or_path resolved_chat_template_file = None + resolved_raw_chat_template_file = None else: processor_file = PROCESSOR_NAME - chat_template_file = CHAT_TEMPLATE_NAME + chat_template_file = "chat_template.json" + raw_chat_template_file = "chat_template.jinja" try: # Load from local folder or from cache or download from model Hub and cache resolved_processor_file = cached_file( @@ -650,6 +657,21 @@ def get_processor_dict( subfolder=subfolder, _raise_exceptions_for_missing_entries=False, ) + + resolved_raw_chat_template_file = cached_file( + pretrained_model_name_or_path, + raw_chat_template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # the original exception. @@ -664,8 +686,11 @@ def get_processor_dict( ) # Add chat template as kwarg before returning because most models don't have processor config - chat_template = None - if resolved_chat_template_file is not None: + if resolved_raw_chat_template_file is not None: + with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: + chat_template = reader.read() + kwargs["chat_template"] = chat_template + elif resolved_chat_template_file is not None: with open(resolved_chat_template_file, "r", encoding="utf-8") as reader: text = reader.read() chat_template = json.loads(text)["chat_template"] @@ -696,7 +721,7 @@ def get_processor_dict( if "chat_template" in processor_dict and processor_dict["chat_template"] is not None: logger.warning_once( - "Chat templates should be in a 'chat_template.json' file but found key='chat_template' " + "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' " "in the processor's config. Make sure to move your template to its own file." ) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index ca5a3bb9c20412..0bfcc4aa303665 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -145,6 +145,7 @@ class EncodingFast: SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" ADDED_TOKENS_FILE = "added_tokens.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json" +CHAT_TEMPLATE_FILE = "chat_template.jinja" # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file FULL_TOKENIZER_FILE = "tokenizer.json" @@ -1941,6 +1942,7 @@ def from_pretrained( "tokenizer_config_file": TOKENIZER_CONFIG_FILE, # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders "tokenizer_file": FULL_TOKENIZER_FILE, + "chat_template_file": CHAT_TEMPLATE_FILE, } vocab_files = {**cls.vocab_files_names, **additional_files_names} if "tokenizer_file" in vocab_files: @@ -2097,6 +2099,12 @@ def _from_pretrained( config_tokenizer_class = None init_kwargs = init_configuration + # If an independent chat template file exists, it takes priority over template entries in the tokenizer config + chat_template_file = resolved_vocab_files.pop("chat_template_file", None) + if chat_template_file is not None: + with open(chat_template_file) as chat_template_handle: + init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config + if not _is_local: if "auto_map" in init_kwargs: # For backward compatibility with odl format. @@ -2396,6 +2404,9 @@ def save_pretrained( tokenizer_config_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE ) + chat_template_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE + ) tokenizer_config = copy.deepcopy(self.init_kwargs) @@ -2418,7 +2429,15 @@ def save_pretrained( if isinstance(self.chat_template, dict): # Chat template dicts are saved to the config as lists of dicts with fixed key names. # They will be reconstructed as a single dict during loading. + # We're trying to discourage chat template dicts, and they are always + # saved in the config, never as single files. tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + elif kwargs.get("save_raw_chat_template", False): + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {chat_template_file}") + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too else: tokenizer_config["chat_template"] = self.chat_template diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 9f0d88089129b8..4c4f6fac49813f 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -18,6 +18,7 @@ import json import random import tempfile +from pathlib import Path from typing import Optional import numpy as np @@ -519,3 +520,27 @@ def test_prepare_and_validate_optional_call_args(self): processor.prepare_and_validate_optional_call_args( *(f"optional_{i}" for i in range(num_optional_call_args + 1)) ) + + def test_chat_template_save_loading(self): + processor = self.get_processor() + existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None) + processor.chat_template = "test template" + with tempfile.TemporaryDirectory() as tmpdirname: + processor.save_pretrained(tmpdirname) + self.assertTrue(Path(tmpdirname, "chat_template.json").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file()) + reloaded_processor = self.processor_class.from_pretrained(tmpdirname) + self.assertEqual(processor.chat_template, reloaded_processor.chat_template) + # When we don't use single-file chat template saving, processor and tokenizer chat templates + # should remain separate + self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template) + + with tempfile.TemporaryDirectory() as tmpdirname: + processor.save_pretrained(tmpdirname, save_raw_chat_template=True) + self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + reloaded_processor = self.processor_class.from_pretrained(tmpdirname) + self.assertEqual(processor.chat_template, reloaded_processor.chat_template) + # When we save as single files, tokenizers and processors share a chat template, which means + # the reloaded tokenizer should get the chat template as well + self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index f04a4255556baf..ed09d800ad6dd5 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -25,6 +25,7 @@ import unittest from collections import OrderedDict from itertools import takewhile +from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from parameterized import parameterized @@ -1107,13 +1108,29 @@ def test_chat_template(self): with tempfile.TemporaryDirectory() as tmp_dir_name: tokenizer.save_pretrained(tmp_dir_name) - tokenizer = tokenizer.from_pretrained(tmp_dir_name) + new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) - self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted - output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) + self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted + output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) self.assertEqual(output, expected_output) # Test output is the same after reloading # Check that no error raised - tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) + new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) + + with tempfile.TemporaryDirectory() as tmp_dir_name: + tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True) + chat_template_file = Path(tmp_dir_name) / "chat_template.jinja" + self.assertTrue(chat_template_file.is_file()) + self.assertEqual(chat_template_file.read_text(), dummy_template) + config_dict = json.loads((Path(tmp_dir_name) / "tokenizer_config.json").read_text()) + # Assert the chat template is not in the config when it's saved as a separate file + self.assertNotIn("chat_template", config_dict) + new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) + + self.assertEqual(new_tokenizer.chat_template, dummy_template) # Test template has persisted + output = new_tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False) + self.assertEqual(output, expected_output) # Test output is the same after reloading + # Check that no error raised + new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) @require_jinja def test_chat_template_batched(self): @@ -1526,18 +1543,40 @@ def test_chat_template_dict_saving(self): tokenizers = self.get_tokenizers() for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): - tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + for save_raw_chat_template in (True, False): + tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + with tempfile.TemporaryDirectory() as tmp_dir_name: + # Test that save_raw_chat_template is ignored when there's a dict of multiple templates + tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template) + config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) + # Assert that chat templates are correctly serialized as lists of dictionaries + self.assertEqual( + config_dict["chat_template"], + [ + {"name": "template1", "template": "{{'a'}}"}, + {"name": "template2", "template": "{{'b'}}"}, + ], + ) + self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja"))) + new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) + # Assert that the serialized list is correctly reconstructed as a single dict + self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template) + + @require_jinja + def test_chat_template_file_priority(self): + dummy_template1 = "a" + dummy_template2 = "b" + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): with tempfile.TemporaryDirectory() as tmp_dir_name: - tokenizer.save_pretrained(tmp_dir_name) - config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) - # Assert that chat templates are correctly serialized as lists of dictionaries - self.assertEqual( - config_dict["chat_template"], - [{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}], - ) + tokenizer.chat_template = dummy_template1 + tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False) + with Path(tmp_dir_name, "chat_template.jinja").open("w") as f: + f.write(dummy_template2) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) - # Assert that the serialized list is correctly reconstructed as a single dict - self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template) + # Assert the file template clobbers any template in the config + self.assertEqual(new_tokenizer.chat_template, dummy_template2) def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False)