diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 556a6f3444..1e6e8e67ad 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1139,7 +1139,7 @@ class DPOVisionTrainerTester(unittest.TestCase):
             ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",),
             # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
-            # ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+            ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
         ]
     )
     def test_vdpo_trainer(self, model_id):
@@ -1211,7 +1211,10 @@ def test_vdpo_trainer(self, model_id):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                if model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and (
+                if model_id in [
+                    "trl-internal-testing/tiny-LlavaForConditionalGeneration",
+                    "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
+                ] and (
                     n.startswith("vision_tower.vision_model.encoder.layers.1")
                     or n == "vision_tower.vision_model.post_layernorm.weight"
                 ):
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 4be0665456..4e9cfb2d66 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -143,6 +143,8 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
             output["pixel_values"] = pad(pixel_values, padding_value=0.0)
             if "pixel_attention_mask" in examples[0]:
                 output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
+            if "image_sizes" in examples[0]:
+                output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
 
         return output
 
@@ -645,6 +647,8 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le
 
         if "pixel_attention_mask" in processed_features:
             output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
+        if "image_sizes" in processed_features:
+            output["image_sizes"] = processed_features["image_sizes"][0]
 
         return output
 
@@ -685,7 +689,7 @@ def _set_signature_columns_if_needed(self):
         # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
         # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
         if self._signature_columns is None:
-            self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids"]
+            self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -855,6 +859,8 @@ def concatenated_inputs(
             output["pixel_attention_mask"] = torch.cat(
                 [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0
             )
+        if "image_sizes" in batch:
+            output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
@@ -1078,6 +1084,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
         if "pixel_attention_mask" in concatenated_batch:
             model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+        if "image_sizes" in concatenated_batch:
+            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
 
         prompt_input_ids = concatenated_batch["prompt_input_ids"]
         prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
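
For reference, a minimal sketch (not part of the patch) of what the new `image_sizes` handling does: LLaVA-NeXT's forward pass expects per-image `(height, width)` metadata, so each tokenized row keeps its `image_sizes`, the collator stacks one row per example, and `concatenated_inputs` duplicates the tensor along the batch dimension so the chosen and rejected completions share the same image metadata in the single concatenated forward pass. Adding `"image_sizes"` to `_signature_columns` keeps the column from being dropped when `remove_unused_columns=True`. The size values below are made up for illustration and assume one image per prompt.

```python
import torch

# Hypothetical per-example features after preprocessing; the (height, width)
# pairs are illustration data, not real processor output.
examples = [
    {"image_sizes": [336, 672]},
    {"image_sizes": [672, 336]},
]

# Collator branch: stack one row per example -> shape (batch_size, 2)
image_sizes = torch.tensor([example["image_sizes"] for example in examples])

# concatenated_inputs branch: duplicate the batch so chosen and rejected
# completions are scored in one forward pass -> shape (2 * batch_size, 2)
concatenated_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)

print(image_sizes.shape)               # torch.Size([2, 2])
print(concatenated_image_sizes.shape)  # torch.Size([4, 2])
```

With that tensor threaded through `concatenated_forward` as a model kwarg, the previously skipped tiny-LlavaNextForConditionalGeneration case can be re-enabled in `test_vdpo_trainer`.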