Skip to content

Commit

Permalink
Fix chatml template (#1248)
Browse files Browse the repository at this point in the history
* first draft

* 64

* Sourab's suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* Update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix template & add test

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
  • Loading branch information
philschmid and younesbelkada authored Jan 18, 2024
1 parent 928d144 commit 1f59eeb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
17 changes: 17 additions & 0 deletions tests/test_dataset_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,20 @@ def test_setup_chat_format(self):
self.assertTrue(len(modified_tokenizer) == original_tokenizer_len + 2)
self.assertTrue(self.model.get_input_embeddings().weight.shape[0] % 64 == 0)
self.assertTrue(self.model.get_input_embeddings().weight.shape[0] == original_tokenizer_len + 64)

def test_example_with_setup_model(self):
    # Render a short three-turn conversation with the tokenizer returned by
    # setup_chat_format and compare it against the expected ChatML text.
    # Only the tokenizer is inspected here; the returned model is not used.
    _, chat_tokenizer = setup_chat_format(self.model, self.tokenizer)

    turns = (
        ("system", "You are helpful"),
        ("user", "Hello"),
        ("assistant", "Hi, how can I help you?"),
    )
    conversation = [{"role": speaker, "content": text} for speaker, text in turns]

    rendered = chat_tokenizer.apply_chat_template(conversation, tokenize=False)

    # One "<|im_start|>{role}\n{content}<|im_end|>\n" segment per turn, in order.
    expected = (
        "<|im_start|>system\nYou are helpful<|im_end|>\n"
        "<|im_start|>user\nHello<|im_end|>\n"
        "<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
    )
    self.assertEqual(rendered, expected)
4 changes: 2 additions & 2 deletions trl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def assistant(self):
def chat_template(self):
    """Build the Jinja chat template string for this chat format.

    The template loops over ``messages`` and, for each one, emits
    ``bos_token + role + newline + content + eos_token + newline``. When
    ``add_generation_prompt`` is truthy in the render context it then
    appends the ``assistant`` prefix followed by a newline.

    Reads ``self.bos_token``, ``self.eos_token`` and ``self.assistant``;
    their values are baked into the template text as quoted Jinja string
    literals when this method runs, so the rendered template does not
    depend on any ``eos_token`` variable being present in the Jinja context.

    NOTE(review): presumably exposed as a property on the enclosing format
    class; the decorator is not visible in this hunk, confirm before use.

    Returns:
        str: the Jinja template source.
    """
    return (
        "{% for message in messages %}"
        # Inside an f-string "{{{{" / "}}}}" collapse to the literal "{{" / "}}"
        # that Jinja requires around an expression. Two braces would collapse
        # to a single brace and Jinja would output the text verbatim instead
        # of evaluating it.
        f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        # Same brace escaping: yields the Jinja expression {{ '<assistant>\n' }}.
        f"{{{{ '{self.assistant}\n' }}}}"
        "{% endif %}"
    )

Expand Down

0 comments on commit 1f59eeb

Please sign in to comment.