Skip to content

Commit 7ab1be6

Browse files
Rocketknight1, Wauplin, julien-c
authored and committed
🚨 🚨 Allow saving and loading multiple "raw" chat template files (huggingface#36588)
* Add saving in the new format (but no loading yet!) * Add saving in the new format (but no loading yet!) * A new approach to template files! * make fixup * make fixup, set correct dir * Some progress but need to rework for cached_file * Rework loading handling again * Small fixes * Looks like it's working now! * make fixup * Working! * make fixup * make fixup * Add TODO so I don't miss it * Cleaner control flow with one less indent * Copy the new logic to processing_utils as well * Proper support for dicts of templates * make fixup * define the file/dir names in a single place * Update the processor chat template reload test as well * Add processor loading of multiple templates * Flatten correctly to match tokenizers * Better support when files are empty sometimes * Stop creating those empty templates * Revert changes now we don't have empty templates * Revert changes now we don't have empty templates * Don't support separate template files on the legacy path * Rework/simplify loading code * Make sure it's always a chat_template key in chat_template.json * Update processor handling of multiple templates * Add a full save-loading test to the tokenizer tests as well * Correct un-flattening * New test was incorrect * Correct error/offline handling * Better exception handling * More error handling cleanup * Add skips for test failing on main * Reorder to fix errors * make fixup * clarify legacy processor file docs and location * Update src/transformers/processing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/processing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/processing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Update src/transformers/processing_utils.py Co-authored-by: Lucain <lucainp@gmail.com> * Rename to _jinja and _legacy * Stop saving multiple templates in the legacy format * Cleanup the processing code * Cleanup the processing code more * make fixup * make fixup * correct 
reformatting * Use correct dir name * Fix import location * Use save_jinja_files instead of save_raw_chat_template_files * Correct the test for saving multiple processor templates * Fix type hint * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Patch llava_onevision test * Update src/transformers/processing_utils.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Update src/transformers/tokenization_utils_base.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Refactor chat template saving out into a separate function * Update tests for the new default * Don't do chat template saving logic when chat template isn't there * Ensure save_jinja_files is propagated to tokenizer correctly * Trigger tests * Update more tests to new default * Trigger tests --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Julien Chaumond <julien@huggingface.co>
1 parent 408ad40 commit 7ab1be6

File tree

9 files changed

+391
-82
lines changed

9 files changed

+391
-82
lines changed

src/transformers/models/llava_onevision/processing_llava_onevision.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,14 @@ def save_pretrained(self, save_directory, **kwargs):
298298
self.video_processor.save_pretrained(video_processor_path)
299299

300300
video_processor_present = "video_processor" in self.attributes
301-
if video_processor_present:
302-
self.attributes.remove("video_processor")
303-
304-
outputs = super().save_pretrained(save_directory, **kwargs)
301+
try:
302+
if video_processor_present:
303+
self.attributes.remove("video_processor")
305304

306-
if video_processor_present:
307-
self.attributes += ["video_processor"]
305+
outputs = super().save_pretrained(save_directory, **kwargs)
306+
finally:
307+
if video_processor_present:
308+
self.attributes += ["video_processor"]
308309
return outputs
309310

310311
# override to load video-config from a separate config file

src/transformers/processing_utils.py

Lines changed: 132 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import numpy as np
2929
import typing_extensions
30+
from huggingface_hub.errors import EntryNotFoundError
3031

3132
from .audio_utils import load_audio
3233
from .dynamic_module_utils import custom_object_save
@@ -52,6 +53,9 @@
5253
TruncationStrategy,
5354
)
5455
from .utils import (
56+
CHAT_TEMPLATE_DIR,
57+
CHAT_TEMPLATE_FILE,
58+
LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
5559
PROCESSOR_NAME,
5660
PushToHubMixin,
5761
TensorType,
@@ -63,6 +67,7 @@
6367
download_url,
6468
is_offline_mode,
6569
is_remote_url,
70+
list_repo_templates,
6671
logging,
6772
)
6873

@@ -618,13 +623,19 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
618623
configs.append(self)
619624
custom_object_save(self, save_directory, config=configs)
620625

626+
save_jinja_files = kwargs.get("save_jinja_files", True)
627+
621628
for attribute_name in self.attributes:
622629
attribute = getattr(self, attribute_name)
623630
# Include the processor class in the attribute config so this processor can then be reloaded with the
624631
# `AutoProcessor` API.
625632
if hasattr(attribute, "_set_processor_class"):
626633
attribute._set_processor_class(self.__class__.__name__)
627-
attribute.save_pretrained(save_directory)
634+
if attribute_name == "tokenizer":
635+
# Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
636+
attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
637+
else:
638+
attribute.save_pretrained(save_directory)
628639

629640
if self._auto_class is not None:
630641
# We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
@@ -636,24 +647,52 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
636647
# If we save using the predefined names, we can load using `from_pretrained`
637648
# plus we save chat_template in its own file
638649
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
639-
output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja")
640-
output_chat_template_file = os.path.join(save_directory, "chat_template.json")
650+
output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE)
651+
output_chat_template_file_legacy = os.path.join(
652+
save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
653+
) # Legacy filename
654+
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
641655

642656
processor_dict = self.to_dict()
643657
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
644658
# to avoid serializing chat template in json config file. So let's get it from `self` directly
645659
if self.chat_template is not None:
646-
if kwargs.get("save_raw_chat_template", False):
647-
with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer:
648-
writer.write(self.chat_template)
649-
logger.info(f"chat template saved in {output_raw_chat_template_file}")
650-
else:
660+
save_jinja_files = kwargs.get("save_jinja_files", True)
661+
is_single_template = isinstance(self.chat_template, str)
662+
if save_jinja_files and is_single_template:
663+
# New format for single templates is to save them as chat_template.jinja
664+
with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
665+
f.write(self.chat_template)
666+
logger.info(f"chat template saved in {output_chat_template_file_jinja}")
667+
elif save_jinja_files and not is_single_template:
668+
# New format for multiple templates is to save the default as chat_template.jinja
669+
# and the other templates in the chat_templates/ directory
670+
for template_name, template in self.chat_template.items():
671+
if template_name == "default":
672+
with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
673+
f.write(self.chat_template["default"])
674+
logger.info(f"chat template saved in {output_chat_template_file_jinja}")
675+
else:
676+
os.makedirs(chat_template_dir, exist_ok=True)
677+
template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
678+
with open(template_filepath, "w", encoding="utf-8") as f:
679+
f.write(template)
680+
logger.info(f"chat template saved in {template_filepath}")
681+
elif is_single_template:
682+
# Legacy format for single templates: Put them in chat_template.json
651683
chat_template_json_string = (
652684
json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
653685
)
654-
with open(output_chat_template_file, "w", encoding="utf-8") as writer:
686+
with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer:
655687
writer.write(chat_template_json_string)
656-
logger.info(f"chat template saved in {output_chat_template_file}")
688+
logger.info(f"chat template saved in {output_chat_template_file_legacy}")
689+
elif self.chat_template is not None:
690+
# At this point we have multiple templates in the legacy format, which is not supported
691+
# chat template dicts are saved to chat_template.json as lists of dicts with fixed key names.
692+
raise ValueError(
693+
"Multiple chat templates are not supported in the legacy format. Please save them as "
694+
"separate files using the `save_jinja_files` argument."
695+
)
657696

658697
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
659698
# `auto_map` is not specified.
@@ -717,6 +756,8 @@ def get_processor_dict(
717756
if os.path.isdir(pretrained_model_name_or_path):
718757
processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
719758

759+
additional_chat_template_files = {}
760+
resolved_additional_chat_template_files = {}
720761
if os.path.isfile(pretrained_model_name_or_path):
721762
resolved_processor_file = pretrained_model_name_or_path
722763
# cant't load chat-template when given a file as pretrained_model_name_or_path
@@ -730,9 +771,25 @@ def get_processor_dict(
730771
resolved_chat_template_file = None
731772
resolved_raw_chat_template_file = None
732773
else:
774+
if is_local:
775+
template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
776+
if template_dir.is_dir():
777+
for template_file in template_dir.glob("*.jinja"):
778+
template_name = template_file.stem
779+
additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
780+
else:
781+
try:
782+
for template in list_repo_templates(
783+
pretrained_model_name_or_path,
784+
local_files_only=local_files_only,
785+
revision=revision,
786+
cache_dir=cache_dir,
787+
):
788+
additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
789+
except EntryNotFoundError:
790+
pass # No template dir means no template files
733791
processor_file = PROCESSOR_NAME
734-
chat_template_file = "chat_template.json"
735-
raw_chat_template_file = "chat_template.jinja"
792+
736793
try:
737794
# Load from local folder or from cache or download from model Hub and cache
738795
resolved_processor_file = cached_file(
@@ -750,12 +807,11 @@ def get_processor_dict(
750807
_raise_exceptions_for_missing_entries=False,
751808
)
752809

753-
# Load chat template from a separate json if exists
754-
# because making it part of processor-config break BC.
755-
# Processors in older version do not accept any kwargs
810+
# chat_template.json is a legacy file used by the processor class
811+
# a raw chat_template.jinja is preferred in future
756812
resolved_chat_template_file = cached_file(
757813
pretrained_model_name_or_path,
758-
chat_template_file,
814+
LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
759815
cache_dir=cache_dir,
760816
force_download=force_download,
761817
proxies=proxies,
@@ -770,7 +826,7 @@ def get_processor_dict(
770826

771827
resolved_raw_chat_template_file = cached_file(
772828
pretrained_model_name_or_path,
773-
raw_chat_template_file,
829+
CHAT_TEMPLATE_FILE,
774830
cache_dir=cache_dir,
775831
force_download=force_download,
776832
proxies=proxies,
@@ -782,6 +838,24 @@ def get_processor_dict(
782838
subfolder=subfolder,
783839
_raise_exceptions_for_missing_entries=False,
784840
)
841+
842+
resolved_additional_chat_template_files = {
843+
template_name: cached_file(
844+
pretrained_model_name_or_path,
845+
template_file,
846+
cache_dir=cache_dir,
847+
force_download=force_download,
848+
proxies=proxies,
849+
resume_download=resume_download,
850+
local_files_only=local_files_only,
851+
token=token,
852+
user_agent=user_agent,
853+
revision=revision,
854+
subfolder=subfolder,
855+
_raise_exceptions_for_missing_entries=False,
856+
)
857+
for template_name, template_file in additional_chat_template_files.items()
858+
}
785859
except OSError:
786860
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
787861
# the original exception.
@@ -796,15 +870,31 @@ def get_processor_dict(
796870
)
797871

798872
# Add chat template as kwarg before returning because most models don't have processor config
799-
if resolved_raw_chat_template_file is not None:
800-
with open(resolved_raw_chat_template_file, encoding="utf-8") as reader:
801-
chat_template = reader.read()
802-
kwargs["chat_template"] = chat_template
803-
elif resolved_chat_template_file is not None:
873+
if resolved_chat_template_file is not None:
874+
# This is the legacy path
804875
with open(resolved_chat_template_file, encoding="utf-8") as reader:
805-
text = reader.read()
806-
chat_template = json.loads(text)["chat_template"]
807-
kwargs["chat_template"] = chat_template
876+
chat_template_json = json.loads(reader.read())
877+
chat_templates = {"default": chat_template_json["chat_template"]}
878+
if resolved_additional_chat_template_files:
879+
raise ValueError(
880+
"Cannot load chat template due to conflicting files - this checkpoint combines "
881+
"a legacy chat_template.json file with separate template files, which is not "
882+
"supported. To resolve this error, replace the legacy chat_template.json file "
883+
"with a modern chat_template.jinja file."
884+
)
885+
else:
886+
chat_templates = {
887+
template_name: open(template_file, "r", encoding="utf-8").read()
888+
for template_name, template_file in resolved_additional_chat_template_files.items()
889+
}
890+
if resolved_raw_chat_template_file is not None:
891+
with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
892+
chat_templates["default"] = reader.read()
893+
if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1:
894+
chat_templates = chat_templates["default"] # Flatten when we just have a single template/file
895+
896+
if chat_templates:
897+
kwargs["chat_template"] = chat_templates
808898

809899
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
810900
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
@@ -1313,14 +1403,27 @@ def apply_chat_template(
13131403
"""
13141404

13151405
if chat_template is None:
1316-
if self.chat_template is not None:
1406+
if isinstance(self.chat_template, dict) and "default" in self.chat_template:
1407+
chat_template = self.chat_template["default"]
1408+
elif isinstance(self.chat_template, dict):
1409+
raise ValueError(
1410+
'The processor has multiple chat templates but none of them are named "default". You need to specify'
1411+
" which one to use by passing the `chat_template` argument. Available templates are: "
1412+
f"{', '.join(self.chat_template.keys())}"
1413+
)
1414+
elif self.chat_template is not None:
13171415
chat_template = self.chat_template
13181416
else:
13191417
raise ValueError(
1320-
"No chat template is set for this processor. Please either set the `chat_template` attribute, "
1321-
"or provide a chat template as an argument. See "
1322-
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
1418+
"Cannot use apply_chat_template because this processor does not have a chat template."
13231419
)
1420+
else:
1421+
if isinstance(self.chat_template, dict) and chat_template in self.chat_template:
1422+
# It's the name of a template, not a full template string
1423+
chat_template = self.chat_template[chat_template]
1424+
else:
1425+
# It's a template string, render it directly
1426+
chat_template = chat_template
13241427

13251428
# Fill sets of kwargs that should be used by different parts of template
13261429
processed_kwargs = {

0 commit comments

Comments
 (0)