Update incorrect data processing in DataCollatorForChatML #2172
Conversation
Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.
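As a minimal sketch of the failure mode (with made-up token ids and hypothetical helper names, not TRL's actual code): the chat template already emits its own BOS/EOS, so prepending special tokens again duplicates the BOS, and right-truncation can drop the EOS:

```python
# Hypothetical token ids, for illustration only (not real vocab ids)
BOS, EOS = 1, 2

def broken_collate(templated_ids, max_length):
    # Old behavior: re-add BOS on top of the template's own BOS,
    # then truncate from the right, which can drop the EOS.
    ids = [BOS] + templated_ids
    return ids[:max_length]

def fixed_collate(templated_ids, max_length):
    # The templated string already starts with BOS and ends with EOS,
    # so use it as-is; only truncate when necessary.
    ids = templated_ids[:]
    if len(ids) > max_length:
        ids = ids[-max_length:]
    return ids

templated = [BOS, 10, 11, 12, EOS]           # e.g. "<s>user: hi assistant: hey</s>"
print(broken_collate(templated, 5))          # [1, 1, 10, 11, 12]: double BOS, EOS lost
print(fixed_collate(templated, 5))           # [1, 10, 11, 12, 2]: single BOS, EOS kept
```

This is only a schematic of the bug being fixed; the real collator works on tokenizer output rather than hand-built lists.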
Awesome @ruijunfeng, can we also have a test for this?
Sure thing. I have tested it on the instruction-tuned versions of the Llama 2 and Gemma 1 series with my own dataset, and it seems to work well. Let me know if you need me to provide anything.
Sorry for the misunderstanding, I meant something like:

```python
import unittest

from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML


class TestDataCollatorForChatML(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.collator = DataCollatorForChatML(tokenizer=self.tokenizer, max_length=20)

    def test_data_collator(self):
        examples = [
            {
                "messages": [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi there! How can I help you today?"},
                    {"role": "user", "content": "What's the weather like?"},
                    {"role": "assistant", "content": "I'm sorry, but I don't have access to real-time weather information."},
                ]
            },
            {
                "messages": [
                    {"role": "user", "content": "Tell me a joke."},
                    {"role": "assistant", "content": "Why don't scientists trust atoms? Because they make up everything!"},
                ]
            },
        ]
        batch = self.collator(examples)

        # Check that all expected keys are present
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertIn("labels", batch)
        self.assertIn("prompts", batch)
        self.assertIn("prompt_attention_mask", batch)

        # Check that the batch dimension matches the number of examples
        self.assertEqual(batch["input_ids"].shape[0], 2)
        self.assertEqual(batch["attention_mask"].shape[0], 2)
        self.assertEqual(batch["labels"].shape[0], 2)
        self.assertEqual(batch["prompts"].shape[0], 2)
        self.assertEqual(batch["prompt_attention_mask"].shape[0], 2)

        # Check that the shapes are consistent
        self.assertEqual(batch["input_ids"].shape, batch["attention_mask"].shape)
        self.assertEqual(batch["input_ids"].shape, batch["labels"].shape)
        self.assertEqual(batch["prompts"].shape, batch["prompt_attention_mask"].shape)

        # Check that the prompts are shorter than or equal to the full input
        self.assertTrue(batch["prompts"].shape[1] <= batch["input_ids"].shape[1])
```

so we can explicitly check for the incorrect data processing and the fix you so kindly provided.
Hi there, I have run your test code, and I think it has a small mistake. You are using the tokenizer for GPT-2:

```python
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

However, GPT-2 does not have a default `chat_template`, so this line of `DataCollatorForChatML` will raise an error:

```python
self.tokenizer.apply_chat_template(messages, tokenize=False)
```

I believe the correct way to test this is to set the `chat_template` manually in your `setUp` function, like this:

```python
tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}{% endfor %}{{ eos_token }}"
```

Alternatively, you could use an instruction-tuned model such as a Llama Instruct variant, whose tokenizer ships with a default `chat_template`.
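For reference, here is roughly what that template renders to, sketched with a plain-Python stand-in for `apply_chat_template` (the real method renders the Jinja template; `render_chat_template` here is purely illustrative):

```python
def render_chat_template(messages, bos_token="<s>", eos_token="</s>"):
    # Mirrors the Jinja template suggested above:
    # {{ bos_token }}{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}{% endfor %}{{ eos_token }}
    body = "".join(f"{m['role']}: {m['content']}" for m in messages)
    return f"{bos_token}{body}{eos_token}"

msgs = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"},
]
print(render_chat_template(msgs))
# <s>user: Hello!assistant: Hi there!</s>
```

Note that with this template there is exactly one BOS at the start and one EOS at the end, which is why the collator must not add special tokens a second time.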
Sorry again for the misunderstanding. What I wanted to say is that you can use the above as a template to write the tests in your PR, and also do remember to do
Thanks a lot for the fix @ruijunfeng ! Would you mind adding a unit test which validates the fix works as expected? This will also help ensure future regressions don't leak into the codebase :)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
@lewtun I added a test that fails on main and passes here, and @ruijunfeng I pushed it into your PR.
@kashif and @lewtun, thank you both for adding the tests and comments. I’ve double-checked the tests and made updates to the assert statements and comments to improve consistency and clarity. Additionally, I noticed that the current check for the EOS token in input_ids only verifies its presence. I have modified it to ensure that the last token of input_ids is the EOS token for a more thorough check.
@qgallouedec fixed the test, taking padding into account.
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@qgallouedec I have:
@ruijunfeng can you kindly check the current refactoring of the data collator? I have simplified it.
@kashif Hi there, I still found a small bug in the refactored code. I used the dataset in your unit test and printed out the results like this:

```python
>>> tokenizer.decode(data["input_ids"][0])
'<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>user: What is better than ugly?assistant: Beautiful.</s>'
>>> tokenizer.decode(data["prompts"][0])
'<s><s><s><s><s><s><s>user: What is better than ugly?</s>'
>>> data["labels"][0, -6:]
tensor([ -100, 22137, 29901, 25685, 29889,     2])
>>> tokenizer.decode(data["labels"][0, -5:])
'istant: Beautiful.</s>'
```

It seems the labels mistakenly include part of the "assistant: " prefix, while the prompts are missing the "assistant: " prefix entirely. Also, from my understanding, shouldn't the prompts exclude the EOS token?
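The boundary issue above comes down to where the prompt length is measured before masking. A minimal sketch, with made-up token ids and a hypothetical `make_labels` helper (not TRL's actual implementation), of how labels are typically masked so the loss covers only the completion:

```python
IGNORE_INDEX = -100  # the label value ignored by the cross-entropy loss

def make_labels(full_ids, prompt_len):
    # Mask the prompt tokens so the loss is computed only on the completion.
    return [IGNORE_INDEX] * prompt_len + full_ids[prompt_len:]

# Illustrative ids: the prompt covers "<s>user: ...assistant: ",
# the completion covers "Beautiful.</s>"
full = [1, 20, 21, 22, 30, 31, 2]
prompt_len = 4  # tokens up to and including "assistant: "
print(make_labels(full, prompt_len))
# [-100, -100, -100, -100, 30, 31, 2]
```

If `prompt_len` is computed from a differently tokenized string (e.g. one that stops before "assistant: "), the mask boundary lands inside the completion, which is exactly the symptom shown in the decoded labels above.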
Just trying to reproduce this on my end. I have as output from the data collator:

and the labels are only set for the completion:

At which point are you printing the data from?
LGTM, I've just added a minor suggestion
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@kashif Sorry, I used a wrong version of the code. I have tried it again and the refactored code is all good.
* Update incorrect data processing in DataCollatorForChatML: fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels
* Update trl/trainer/utils.py (Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>)
* style
* move comment
* add test for DataCollatorForChatML
* update comment with more details
* update assert reports and comments, and add verification that the last token of input_ids should be the EOS token
* new line at the end of file for code quality
* Update tests/test_utils.py
* Update tests/test_utils.py
* Update tests/test_utils.py
* update tests
* fix test
* Update tests/test_utils.py (Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>)
* Update tests/test_utils.py (Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>)
* formatting
* fix typo
* simplify
* Revert "simplify" (this reverts commit 7e4006c)
* tokenize full messages
* dont add eos
* eos is in the last token
* simplify DataCollatorForChatML (Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>)

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
What does this PR do?
Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.
Fixes #2169
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. (Incorrect data processing in DataCollatorForChatML #2169)
- Did you make sure to update the documentation with your changes? See the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.