Skip to content

Commit f9ecaf2

Browse files
committed
make fixup
1 parent a4542f6 commit f9ecaf2

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

tests/test_tokenization_common.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1626,20 +1626,28 @@ def test_chat_template_dict_saving(self):
16261626
for tokenizer in tokenizers:
16271627
with self.subTest(f"{tokenizer.__class__.__name__}"):
16281628
for save_raw_chat_template in (True, False):
1629-
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
1629+
tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2}
16301630
with tempfile.TemporaryDirectory() as tmp_dir_name:
16311631
# Test that save_raw_chat_template is ignored when there's a dict of multiple templates
16321632
tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template)
1633-
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
1634-
# Assert that chat templates are correctly serialized as lists of dictionaries
1635-
self.assertEqual(
1636-
config_dict["chat_template"],
1637-
[
1638-
{"name": "template1", "template": "{{'a'}}"},
1639-
{"name": "template2", "template": "{{'b'}}"},
1640-
],
1641-
)
1642-
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
1633+
if save_raw_chat_template:
1634+
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
1635+
self.assertNotIn("chat_template", config_dict)
1636+
self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
1637+
self.assertTrue(
1638+
os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja"))
1639+
)
1640+
else:
1641+
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
1642+
# Assert that chat templates are correctly serialized as lists of dictionaries
1643+
self.assertEqual(
1644+
config_dict["chat_template"],
1645+
[
1646+
{"name": "default", "template": "{{'a'}}"},
1647+
{"name": "template2", "template": "{{'b'}}"},
1648+
],
1649+
)
1650+
self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja")))
16431651
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
16441652
# Assert that the serialized list is correctly reconstructed as a single dict
16451653
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)

0 commit comments

Comments
 (0)