From 99451b421a007af58c5d58feb9f5477310083c8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 22 Dec 2024 12:43:55 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=AC=20Rename=20collator=20`PreferenceC?= =?UTF-8?q?ollator`=20to=20`=20DataCollatorForPreference`=20(#2510)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/dpo_trainer.mdx | 4 +- tests/test_collators.py | 74 +++++++++++++++++++++++++++++++++++++ trl/trainer/dpo_trainer.py | 12 +++--- 3 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 tests/test_collators.py diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 103326fd9b..b0d6b1f8d6 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -278,6 +278,6 @@ dpo_trainer = DPOTrainer( [[autodoc]] DPOConfig -## PreferenceCollator +## DataCollatorForPreference -[[autodoc]] trainer.dpo_trainer.PreferenceCollator \ No newline at end of file +[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference \ No newline at end of file diff --git a/tests/test_collators.py b/tests/test_collators.py new file mode 100644 index 0000000000..2d02d77b9d --- /dev/null +++ b/tests/test_collators.py @@ -0,0 +1,74 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from trl.trainer.dpo_trainer import DataCollatorForPreference + + +class TestDataCollatorForPreference(unittest.TestCase): + def setUp(self): + self.collator = DataCollatorForPreference(pad_token_id=0) + + def assertTensorEqual(self, tensor1, tensor2): + self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") + + def test_padding_behavior(self): + examples = [ + {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, + {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}, + ] + output = self.collator.torch_call(examples) + + expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 7, 8]]) + expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) + expected_chosen_input_ids = torch.tensor([[4, 5], [9, 10]]) + expected_chosen_attention_mask = torch.tensor([[1, 1], [1, 1]]) + expected_rejected_input_ids = torch.tensor([[6, 0, 0], [11, 12, 13]]) + expected_rejected_attention_mask = torch.tensor([[1, 0, 0], [1, 1, 1]]) + + self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids) + self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask) + self.assertTensorEqual(output["chosen_input_ids"], expected_chosen_input_ids) + self.assertTensorEqual(output["chosen_attention_mask"], expected_chosen_attention_mask) + self.assertTensorEqual(output["rejected_input_ids"], expected_rejected_input_ids) + self.assertTensorEqual(output["rejected_attention_mask"], expected_rejected_attention_mask) + + def test_optional_fields(self): + examples = [ + { + "prompt_input_ids": [1], + "chosen_input_ids": [2], + "rejected_input_ids": [3], + "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]], # Example 3D tensor (1x2x2) + }, + { + "prompt_input_ids": [4], + "chosen_input_ids": [5], + "rejected_input_ids": [6], + "pixel_values": [[[0.5, 0.6], [0.7, 0.8]]], # Example 3D tensor (1x2x2) + }, + ] + output = self.collator.torch_call(examples) + + expected_pixel_values = torch.tensor( + [ + [[[0.1, 0.2], [0.3, 0.4]]], + [[[0.5, 0.6], [0.7, 0.8]]], + ] + ) # Shape: (2, 1, 2, 2) + + self.assertTensorEqual(output["pixel_values"], expected_pixel_values) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index d820857de1..cfb135545e 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -79,7 +79,7 @@ @dataclass -class PreferenceCollator(DataCollatorMixin): +class DataCollatorForPreference(DataCollatorMixin): """ Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they are not all of the same length. @@ -92,8 +92,8 @@ class PreferenceCollator(DataCollatorMixin): Examples: ```python - >>> from trl import PreferenceCollator - >>> collator = PreferenceCollator(pad_token_id=0) + >>> from trl import DataCollatorForPreference + >>> collator = DataCollatorForPreference(pad_token_id=0) >>> examples = [ ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]} @@ -168,7 +168,7 @@ class DPOTrainer(Trainer): args (`DPOConfig`): The DPO config arguments to use for training. data_collator (`transformers.DataCollator`): - The data collator to use for training. If None is specified, the default data collator (`PreferenceCollator`) will be used + The data collator to use for training. If None is specified, the default data collator (`DataCollatorForPreference`) will be used which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. train_dataset (`datasets.Dataset`): The dataset to use for training. @@ -374,7 +374,7 @@ def make_inputs_require_grad(module, input, output): ) if data_collator is None: - data_collator = PreferenceCollator(pad_token_id=self.padding_value) + data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) # Disable dropout in the model and reference model if args.disable_dropout: @@ -684,7 +684,7 @@ def _set_signature_columns_if_needed(self): # If `self.args.remove_unused_columns` is True, non-signature columns are removed. # By default, this method sets `self._signature_columns` to the model's expected inputs. # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. - # Instead, we set them to the columns expected by `PreferenceCollator`, hence the override. + # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. if self._signature_columns is None: self._signature_columns = [ "prompt_input_ids",