diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index c30e2b6c2e..3f4ee1503b 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -472,8 +472,8 @@ def prediction_step( # logits for the chosen and rejected samples from model logits_dict = { - "logits_test/chosen": metrics["logits_test/chosen"], - "logits_test/rejected": metrics["logits_test/rejected"], + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], } logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) logits = torch.stack(logits).mean(axis=1)