Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👬 Rename collator PreferenceCollator to DataCollatorForPreference #2510

Merged
merged 3 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/dpo_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,6 @@ dpo_trainer = DPOTrainer(

[[autodoc]] DPOConfig

## PreferenceCollator
## DataCollatorForPreference

[[autodoc]] trainer.dpo_trainer.PreferenceCollator
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
74 changes: 74 additions & 0 deletions tests/test_collators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from trl.trainer.dpo_trainer import DataCollatorForPreference


class TestDataCollatorForPreference(unittest.TestCase):
def setUp(self):
self.collator = DataCollatorForPreference(pad_token_id=0)

def assertTensorEqual(self, tensor1, tensor2):
self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}")

def test_padding_behavior(self):
examples = [
{"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
{"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
]
output = self.collator.torch_call(examples)

expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 7, 8]])
expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
expected_chosen_input_ids = torch.tensor([[4, 5], [9, 10]])
expected_chosen_attention_mask = torch.tensor([[1, 1], [1, 1]])
expected_rejected_input_ids = torch.tensor([[6, 0, 0], [11, 12, 13]])
expected_rejected_attention_mask = torch.tensor([[1, 0, 0], [1, 1, 1]])

self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids)
self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask)
self.assertTensorEqual(output["chosen_input_ids"], expected_chosen_input_ids)
self.assertTensorEqual(output["chosen_attention_mask"], expected_chosen_attention_mask)
self.assertTensorEqual(output["rejected_input_ids"], expected_rejected_input_ids)
self.assertTensorEqual(output["rejected_attention_mask"], expected_rejected_attention_mask)

def test_optional_fields(self):
examples = [
{
"prompt_input_ids": [1],
"chosen_input_ids": [2],
"rejected_input_ids": [3],
"pixel_values": [[[0.1, 0.2], [0.3, 0.4]]], # Example 3D tensor (1x2x2)
},
{
"prompt_input_ids": [4],
"chosen_input_ids": [5],
"rejected_input_ids": [6],
"pixel_values": [[[0.5, 0.6], [0.7, 0.8]]], # Example 3D tensor (1x2x2)
},
]
output = self.collator.torch_call(examples)

expected_pixel_values = torch.tensor(
[
[[[0.1, 0.2], [0.3, 0.4]]],
[[[0.5, 0.6], [0.7, 0.8]]],
]
) # Shape: (2, 1, 2, 2)

self.assertTensorEqual(output["pixel_values"], expected_pixel_values)
12 changes: 6 additions & 6 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@


@dataclass
class PreferenceCollator(DataCollatorMixin):
class DataCollatorForPreference(DataCollatorMixin):
"""
Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they
are not all of the same length.
Expand All @@ -92,8 +92,8 @@ class PreferenceCollator(DataCollatorMixin):

Examples:
```python
>>> from trl import PreferenceCollator
>>> collator = PreferenceCollator(pad_token_id=0)
>>> from trl import DataCollatorForPreference
>>> collator = DataCollatorForPreference(pad_token_id=0)
>>> examples = [
... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}
Expand Down Expand Up @@ -168,7 +168,7 @@ class DPOTrainer(Trainer):
args (`DPOConfig`):
The DPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`PreferenceCollator`) will be used
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForPreference`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
Expand Down Expand Up @@ -374,7 +374,7 @@ def make_inputs_require_grad(module, input, output):
)

if data_collator is None:
data_collator = PreferenceCollator(pad_token_id=self.padding_value)
data_collator = DataCollatorForPreference(pad_token_id=self.padding_value)

# Disable dropout in the model and reference model
if args.disable_dropout:
Expand Down Expand Up @@ -684,7 +684,7 @@ def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by `PreferenceCollator`, hence the override.
# Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override.
if self._signature_columns is None:
self._signature_columns = [
"prompt_input_ids",
Expand Down
Loading