
Update incorrect data processing in DataCollatorForChatML #2172

Merged
merged 31 commits into huggingface:main on Oct 10, 2024

Conversation

ruijunfeng
Contributor

What does this PR do?

Fix the extra BOS token and the missing EOS token in the returned input_ids, and the potentially missing target string in the returned labels.

Fixes #2169
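For context, the bug can be sketched with a toy tokenizer (all names and token ids here are illustrative, not the actual TRL code): the chat template output already starts with BOS, so tokenizing it again with special tokens enabled duplicates the BOS token.

```python
# Toy sketch of the BOS/EOS issue (illustrative ids, not the real tokenizer).
BOS, EOS = 1, 2

def encode(ids, add_special_tokens=True):
    # Stand-in for tokenizer(...): optionally prepends a BOS token.
    return ([BOS] if add_special_tokens else []) + list(ids)

# apply_chat_template(...) output already starts with BOS and ends with EOS.
formatted = [BOS, 10, 11, 12, EOS]

buggy = encode(formatted)                            # duplicates BOS
fixed = encode(formatted, add_special_tokens=False)  # keeps exactly one BOS

assert buggy[:2] == [BOS, BOS]               # the bug: two BOS tokens
assert fixed[0] == BOS and fixed[-1] == EOS  # the fix: one BOS, trailing EOS
```

The fix in the PR follows the same idea: tokenize the already-templated string without adding special tokens again.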

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif
Collaborator

kashif commented Oct 4, 2024

awesome @ruijunfeng can we also have a test for this?

@ruijunfeng
Contributor Author

ruijunfeng commented Oct 4, 2024

awesome @ruijunfeng can we also have a test for this?

Sure thing. I have tested it on the instruct-tuned versions of the Llama 2 and Gemma 1 series with my own dataset, and it seems to work well. Let me know if you need me to provide anything else.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif
Collaborator

kashif commented Oct 6, 2024

sorry for the misunderstanding, I meant something like:

import unittest

from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML


class TestDataCollatorForChatML(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.collator = DataCollatorForChatML(tokenizer=self.tokenizer, max_length=20)

    def test_data_collator(self):
        examples = [
            {
                "messages": [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi there! How can I help you today?"},
                    {"role": "user", "content": "What's the weather like?"},
                    {"role": "assistant", "content": "I'm sorry, but I don't have access to real-time weather information."},
                ]
            },
            {
                "messages": [
                    {"role": "user", "content": "Tell me a joke."},
                    {"role": "assistant", "content": "Why don't scientists trust atoms? Because they make up everything!"},
                ]
            }
        ]

        batch = self.collator(examples)

        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertIn("labels", batch)
        self.assertIn("prompts", batch)
        self.assertIn("prompt_attention_mask", batch)

        self.assertEqual(batch["input_ids"].shape[0], 2)
        self.assertEqual(batch["attention_mask"].shape[0], 2)
        self.assertEqual(batch["labels"].shape[0], 2)
        self.assertEqual(batch["prompts"].shape[0], 2)
        self.assertEqual(batch["prompt_attention_mask"].shape[0], 2)

        # Check if the shapes are consistent
        self.assertEqual(batch["input_ids"].shape, batch["attention_mask"].shape)
        self.assertEqual(batch["input_ids"].shape, batch["labels"].shape)
        self.assertEqual(batch["prompts"].shape, batch["prompt_attention_mask"].shape)

        # Check if the prompts are shorter than or equal to the full input
        self.assertLessEqual(batch["prompts"].shape[1], batch["input_ids"].shape[1])

so we can explicitly check for the incorrect data processing and the fix you so kindly provided

@kashif kashif added the 🏋 GKD Related to GKD label Oct 6, 2024
@ruijunfeng
Contributor Author

Hi there, I have run your test code, and I think it has a small mistake. You are using the tokenizer for GPT-2:

self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

However, GPT-2 does not have a default chat_template, so this line of DataCollatorForChatML will raise an error:

self.tokenizer.apply_chat_template(messages, tokenize=False)

I believe the correct way to test this is by manually setting the chat_template for the tokenizer, like this in your setup function:

tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}{% endfor %}{{ eos_token }}"

Alternatively, you could use an instruction-tuned model, such as a Llama instruct variant, whose tokenizer has a default chat_template.
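To see what string the manual template above produces, it can be emulated in plain Python (the special-token strings here are toy stand-ins; the real values come from the tokenizer):

```python
# Emulate the Jinja chat template with plain Python (toy special tokens).
bos_token, eos_token = "<s>", "</s>"
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"},
]
# Mirrors: {{ bos_token }}{% for message in messages %}
#          {{ message['role'] }}: {{ message['content'] }}{% endfor %}{{ eos_token }}
rendered = bos_token + "".join(f"{m['role']}: {m['content']}" for m in messages) + eos_token

assert rendered == "<s>user: Hello!assistant: Hi there!</s>"
```

Note that BOS and EOS appear exactly once each in the rendered string, which is what makes the double-BOS bug observable in the collator output.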

@kashif
Collaborator

kashif commented Oct 7, 2024

Sorry again for the misunderstanding. What I wanted to say is that you can use the above as a template to write the tests in your PR. Also, remember to run make precommit to fix the formatting, etc.

Member

@lewtun lewtun left a comment


Thanks a lot for the fix @ruijunfeng ! Would you mind adding a unit test which validates the fix works as expected? This will also help ensure future regressions don't leak into the codebase :)

@kashif kashif added the 🐛 bug Something isn't working label Oct 8, 2024
@kashif
Collaborator

kashif commented Oct 8, 2024

@lewtun i added a test that fails on main and passes here and @ruijunfeng I pushed it into your PR

@ruijunfeng
Contributor Author

@kashif and @lewtun, thank you both for adding the tests and comments. I’ve double-checked the tests and made updates to the assert statements and comments to improve consistency and clarity. Additionally, I noticed that the current check for the EOS token in input_ids only verifies its presence. I have modified it to ensure that the last token of input_ids is the EOS token for a more thorough check.
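The stronger check described here might look something like the following sketch (variable names and token ids are assumptions, not the exact test code):

```python
# Sketch of the stronger EOS check: presence vs. last-token position.
eos_token_id = 2
input_ids = [0, 0, 1, 5, 6, 2]  # toy left-padded sequence ending in EOS

weak_check = eos_token_id in input_ids        # only verifies EOS appears somewhere
strong_check = input_ids[-1] == eos_token_id  # verifies EOS is the final token

assert weak_check and strong_check
```

The weak check would also pass if EOS appeared mid-sequence (e.g. as a pad token) while the sequence ended on something else, which is exactly the failure mode the stronger assertion rules out.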

@kashif
Collaborator

kashif commented Oct 8, 2024

@qgallouedec fixed the test taking padding into account

kashif and others added 4 commits October 8, 2024 17:16
@kashif
Collaborator

kashif commented Oct 9, 2024

@qgallouedec I have:

input_ids[15:]  # the first 15 tokens are padding
[1, 518, 25580, 29962, 1724, 338, 2253, 1135, 22769, 29973, 518, 29914, 25580, 29962, 25685, 29889, 29871, 2]
(Pdb) self.tokenizer(self.tokenizer.apply_chat_template(self.examples[0]["messages"], tokenize=False), add_special_tokens=False)
{'input_ids': [1, 518, 25580, 29962, 1724, 338, 2253, 1135, 22769, 29973, 518, 29914, 25580, 29962, 25685, 29889, 29871, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
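The padding referred to here is left padding, which keeps the real tokens (and the trailing EOS) at the end of the sequence. A minimal sketch, assuming pad id 0 and toy token ids:

```python
# Minimal left-padding sketch (assumed pad id 0, toy token ids).
pad_id, eos_id = 0, 2
seq = [1, 518, 25580, 2]  # toy tokenized conversation ending in EOS
max_length = 7

n_pad = max_length - len(seq)
padded = [pad_id] * n_pad + seq
attention_mask = [0] * n_pad + [1] * len(seq)

assert padded[-len(seq):] == seq  # real tokens sit untouched at the end
assert padded[-1] == eos_id       # EOS remains the last token
```

This is why slicing off the leading pad tokens (as in the pdb session above) recovers exactly the tokenized chat-template output.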

@kashif kashif self-requested a review October 9, 2024 11:07
@kashif
Collaborator

kashif commented Oct 9, 2024

@ruijunfeng can you kindly check with the current refactoring of the datacollator, I have simplified it

@ruijunfeng
Contributor Author

ruijunfeng commented Oct 10, 2024

@kashif Hi there, I still found a small bug in the refactored code. I used the dataset in your unit test and printed out the results like this:

>>> tokenizer.decode(data["input_ids"][0])
'<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>user: What is better than ugly?assistant: Beautiful.</s>'
>>> tokenizer.decode(data["prompts"][0])
'<s><s><s><s><s><s><s>user: What is better than ugly?</s>'
>>> data["labels"][0, -6:]
tensor([ -100, 22137, 29901, 25685, 29889,     2])
>>> tokenizer.decode(data["labels"][0, -5:])
'istant: Beautiful.</s>'

It seems the labels mistakenly include part of "assistant: ", and the prompts are missing the "assistant: " prefix. Also, from my understanding, shouldn't the prompts exclude the EOS token?

@kashif
Collaborator

kashif commented Oct 10, 2024

Just trying to reproduce this on my end; I get the following output from the data collator:

self.tokenizer.decode(input_ids[0])
'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> [INST] What is better than ugly? [/INST] Beautiful. </s>'
self.tokenizer.decode(prompts_input_ids[0])
'</s></s></s></s></s></s><s> [INST] What is better than ugly? [/INST]'

and the labels are only set for the completion:

labels[0]
tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 25685,
        29889, 29871,     2])
self.tokenizer.decode(labels[0][-4:])
'Beautiful. </s>'

at which point are you printing the data from?
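The completion-only labels shown in the output above follow the usual pattern of masking every prompt token with -100 so that the loss is computed only on the completion. A toy sketch (the prompt length and token ids are illustrative):

```python
# Toy sketch of completion-only label masking (illustrative prompt length).
IGNORE_INDEX = -100
input_ids = [1, 518, 25580, 29962, 1724, 25685, 29889, 2]
prompt_length = 5  # number of tokens that belong to the prompt

labels = [IGNORE_INDEX] * prompt_length + input_ids[prompt_length:]

assert labels[:prompt_length] == [IGNORE_INDEX] * prompt_length
assert labels[prompt_length:] == input_ids[prompt_length:]  # loss only on the completion
```

The bug being debated is where `prompt_length` ends: if it is computed a token too short, part of the "assistant:" header leaks into the labels, as seen in the earlier output.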

Member

@qgallouedec qgallouedec left a comment


LGTM, I've just added a minor suggestion

@ruijunfeng
Contributor Author

@kashif Sorry, I used the wrong version of the code. I have tried it again and the refactored code is all good.

@kashif kashif merged commit 3107a40 into huggingface:main Oct 10, 2024
9 checks passed
qgallouedec added a commit that referenced this pull request Oct 10, 2024
* Update incorrect data processing in DataCollatorForChatML

Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

* move comment

* add test for DataCollatorForChatML

* update comment with more details

* update assert reports and comments, and adds verification that the last token of input_ids should be EOS token

* new line at the end of file for code quality

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* update tests

* fix test

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

* fix typo

* simplify

* Revert "simplify"

This reverts commit 7e4006c.

* tokenize full messages

* dont add eos

* eos is in the last token

* simplify DataCollatorForChatML

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Labels
🐛 bug Something isn't working 🏋 GKD Related to GKD
Linked issue: Incorrect data processing in DataCollatorForChatML
5 participants