diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 50e72ce0e5..1a38c7f41f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -556,7 +556,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[ """Generate samples from the model and reference model for the given batch of inputs.""" policy_output = model.generate( - batch["prompt_input_ids"], + input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True,