
Commit 313afcc

[chat template] update when "push_to_hub" (#39815)

* update templates push to hub
* revert jinja suffix and move it to processor file

1 parent 7bba4d1 · commit 313afcc
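In short: `save_pretrained` forwards its leftover `**kwargs` to `push_to_hub`, so the `save_jinja_files` flag has to be consumed before forwarding; this commit moves the pop to the top of `save_pretrained` in both the processor and tokenizer code paths. A minimal sketch of the now-working call (the local directory and repo id are placeholders; any processor or tokenizer with a chat template behaves the same way):

# Sketch only: "local-dir" and "my-user/my-processor" are placeholder names.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint

# Push to the Hub while keeping the legacy template format (template stored
# in the JSON config rather than a standalone chat_template.jinja file).
processor.save_pretrained(
    "local-dir",
    repo_id="my-user/my-processor",
    push_to_hub=True,
    save_jinja_files=False,
)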

File tree

5 files changed: +74 −5 lines changed

src/transformers/processing_utils.py

Lines changed: 2 additions & 3 deletions

@@ -776,6 +776,7 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         use_auth_token = kwargs.pop("use_auth_token", None)
+        save_jinja_files = kwargs.pop("save_jinja_files", True)
 
         if use_auth_token is not None:
             warnings.warn(
@@ -803,8 +804,6 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri
             configs.append(self)
             custom_object_save(self, save_directory, config=configs)
 
-        save_jinja_files = kwargs.get("save_jinja_files", True)
-
         for attribute_name in self.attributes:
             # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
             if attribute_name == "tokenizer":
@@ -840,7 +839,6 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_seri
         # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
         # to avoid serializing chat template in json config file. So let's get it from `self` directly
         if self.chat_template is not None:
-            save_jinja_files = kwargs.get("save_jinja_files", True)
            is_single_template = isinstance(self.chat_template, str)
            if save_jinja_files and is_single_template:
                # New format for single templates is to save them as chat_template.jinja
@@ -999,6 +997,7 @@ def get_processor_dict(
                    cache_dir=cache_dir,
                    token=token,
                ):
+                    template = template.removesuffix(".jinja")
                    additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
            except EntryNotFoundError:
                pass  # No template dir means no template files
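Two things change here. First, `save_jinja_files` is read with `kwargs.pop` at the top of the method instead of `kwargs.get` further down: the remaining kwargs are passed along to `push_to_hub`, and a `get` leaves the flag in the dict to be forwarded. A standalone sketch of the difference, with hypothetical function names standing in for save_pretrained/push_to_hub:

def upload(**kwargs):
    # Stand-in for push_to_hub: it receives whatever is left in kwargs.
    print("forwarded kwargs:", kwargs)

def save_with_get(**kwargs):
    save_jinja_files = kwargs.get("save_jinja_files", True)  # flag stays in kwargs
    upload(**kwargs)

def save_with_pop(**kwargs):
    save_jinja_files = kwargs.pop("save_jinja_files", True)  # flag is consumed
    upload(**kwargs)

save_with_get(save_jinja_files=False)  # forwarded kwargs: {'save_jinja_files': False}
save_with_pop(save_jinja_files=False)  # forwarded kwargs: {}

Second, `get_processor_dict` now strips the ".jinja" extension from template names returned by `list_repo_templates`, so the mapping is keyed by the bare template name while the file path keeps the suffix. `str.removesuffix` is a no-op when the suffix is absent ("tool_use" below is a hypothetical template name):

assert "tool_use.jinja".removesuffix(".jinja") == "tool_use"
assert "tool_use".removesuffix(".jinja") == "tool_use"  # unchanged when no suffix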

src/transformers/tokenization_utils_base.py

Lines changed: 1 addition & 1 deletion

@@ -2512,6 +2512,7 @@ def save_pretrained(
             A tuple of `str`: The files saved.
         """
         use_auth_token = kwargs.pop("use_auth_token", None)
+        save_jinja_files = kwargs.pop("save_jinja_files", True)
 
         if use_auth_token is not None:
             warnings.warn(
@@ -2560,7 +2561,6 @@ def save_pretrained(
             tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
             tokenizer_config.update(self.extra_special_tokens)
 
-        save_jinja_files = kwargs.get("save_jinja_files", True)
         tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
             save_directory, tokenizer_config, filename_prefix, save_jinja_files
         )
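The tokenizer side gets the same pop-early fix, with the flag threaded through to `save_chat_templates`. A local round-trip sketch of the two on-disk formats the flag selects, assuming the current behavior where the single-file format writes chat_template.jinja and the legacy format keeps the template inside tokenizer_config.json:

import json, os, tempfile
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
tokenizer.chat_template = "test template"

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)  # save_jinja_files defaults to True
    assert os.path.isfile(os.path.join(tmp_dir, "chat_template.jinja"))

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir, save_jinja_files=False)  # legacy format
    with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
        assert json.load(f)["chat_template"] == "test template"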

src/transformers/utils/hub.py

Lines changed: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ def list_repo_templates(
     local_files_only: bool,
     revision: Optional[str] = None,
     cache_dir: Optional[str] = None,
-    token: Union[bool, str, None] = None,
+    token: Optional[Union[str, bool]] = None,
 ) -> list[str]:
     """List template files from a repo.
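The annotation change is cosmetic: `Optional[X]` is defined as `Union[X, None]`, and `Union` flattens and ignores argument order, so both spellings denote the same type:

from typing import Optional, Union

assert Optional[Union[str, bool]] == Union[bool, str, None]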

tests/models/auto/test_processor_auto.py

Lines changed: 44 additions & 0 deletions

@@ -34,7 +34,10 @@
     AutoProcessor,
     AutoTokenizer,
     BertTokenizer,
+    LlamaTokenizer,
+    LlavaProcessor,
     ProcessorMixin,
+    SiglipImageProcessor,
     Wav2Vec2Config,
     Wav2Vec2FeatureExtractor,
     Wav2Vec2Processor,
@@ -57,6 +60,7 @@
 
 
 SAMPLE_PROCESSOR_CONFIG = get_tests_dir("fixtures/dummy_feature_extractor_config.json")
+SAMPLE_VOCAB_LLAMA = get_tests_dir("fixtures/test_sentencepiece.model")
 SAMPLE_VOCAB = get_tests_dir("fixtures/vocab.json")
 SAMPLE_PROCESSOR_CONFIG_DIR = get_tests_dir("fixtures")
 
@@ -503,3 +507,43 @@ def test_push_to_hub_dynamic_processor(self):
         new_processor = AutoProcessor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
         # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
         self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
+
+    def test_push_to_hub_with_chat_templates(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tokenizer = LlamaTokenizer(SAMPLE_VOCAB_LLAMA, keep_accents=True)
+            image_processor = SiglipImageProcessor()
+            chat_template = "default dummy template for testing purposes only"
+            processor = LlavaProcessor(
+                tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template
+            )
+            self.assertEqual(processor.chat_template, chat_template)
+
+            existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.save_pretrained(
+                    tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
+                )
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we don't use single-file chat template saving, processor and tokenizer chat templates
+                # should remain separate
+                self.assertEqual(
+                    getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template
+                )
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we save as single files, tokenizers and processors share a chat template, which means
+                # the reloaded tokenizer should get the chat template as well
+                self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.chat_template = {"default": "a", "secondary": "b"}
+                processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we save as single files, tokenizers and processors share a chat template, which means
+                # the reloaded tokenizer should get the chat template as well
+                self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
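The sharing behavior asserted above can also be checked locally, without a Hub repo; a minimal sketch reusing the fixture named in the diff:

import tempfile
from transformers import LlamaTokenizer, LlavaProcessor, SiglipImageProcessor
from transformers.testing_utils import get_tests_dir

SAMPLE_VOCAB_LLAMA = get_tests_dir("fixtures/test_sentencepiece.model")

processor = LlavaProcessor(
    tokenizer=LlamaTokenizer(SAMPLE_VOCAB_LLAMA, keep_accents=True),
    image_processor=SiglipImageProcessor(),
    chat_template="default dummy template for testing purposes only",
)

with tempfile.TemporaryDirectory() as tmp_dir:
    processor.save_pretrained(tmp_dir)  # single-file format by default
    reloaded = LlavaProcessor.from_pretrained(tmp_dir)
    # With single-file saving, the reloaded tokenizer shares the processor's template.
    assert reloaded.chat_template == reloaded.tokenizer.chat_template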

tests/utils/test_tokenization_utils.py

Lines changed: 26 additions & 0 deletions

@@ -131,6 +131,32 @@ def test_push_to_hub(self):
             new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
             self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
 
+    def test_push_to_hub_chat_templates(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = BertTokenizer(vocab_file)
+            tokenizer.chat_template = "test template"
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.save_pretrained(
+                    tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
+                )
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.chat_template = {"default": "a", "secondary": "b"}
+                tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
     def test_push_to_hub_via_save_pretrained(self):
         with TemporaryHubRepo(token=self._token) as tmp_repo:
             with tempfile.TemporaryDirectory() as tmp_dir:
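The dict-of-templates case in the last block round-trips the same way locally; a sketch, assuming multi-template saving writes one .jinja file per named template and that from_pretrained reassembles them into a dict ("hf-internal-testing/tiny-random-bert" is just a small example checkpoint):

import tempfile
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
tokenizer.chat_template = {"default": "a", "secondary": "b"}

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)  # default: Jinja files, not JSON config
    reloaded = AutoTokenizer.from_pretrained(tmp_dir)
    assert reloaded.chat_template == tokenizer.chat_template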
