From 8febff8000da9dafa7dba5ac03a7e32fb3f46d56 Mon Sep 17 00:00:00 2001
From: zhc7
Date: Sat, 14 Dec 2024 01:26:17 +0800
Subject: [PATCH] dpo_trainer gather metrics across ranks before logging
 according to https://github.com/huggingface/trl/issues/2468

---
 trl/trainer/dpo_trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 7ed0ac387f..f26ae9b4b2 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1424,7 +1424,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
         train_eval = "train" if "loss" in logs else "eval"
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
+            if isinstance(metrics[0], torch.Tensor):
+                gathered = self._nested_gather([m.cuda() for m in metrics])
+                metrics = [g.mean() for g in gathered]
+            meaned = torch.tensor(metrics).mean()
+            logs[key] = meaned.item()
         del self._stored_metrics[train_eval]
 
         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
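
Note on the fix (not part of the patch): each process stores only the metric
tensors from its own batches, so before this change rank 0 logged a mean
computed from its shard of the data alone. The patch gathers every stored
tensor across ranks via `_nested_gather` (inherited from `transformers.Trainer`)
before averaging, and the `m.cuda()` call moves tensors to the GPU because the
NCCL backend used for multi-GPU training can only gather CUDA tensors. Below is
a minimal standalone sketch of the same idea using plain `torch.distributed`;
it uses the CPU-only "gloo" backend so it runs anywhere, and the two-process
world size, port number, and `worker` function are illustrative choices, not
TRL code.

    # Minimal sketch: why a per-rank mean differs from the gathered global mean.
    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"  # arbitrary free port for the demo
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Each rank holds a different locally averaged metric
        # (e.g. rewards/chosen): rank 0 -> 1.0, rank 1 -> 2.0.
        local_metric = torch.tensor([float(rank + 1)])

        # Without gathering, rank 0 would log only its own value (1.0).
        # Gathering first lets it compute the true mean over all ranks.
        gathered = [torch.zeros_like(local_metric) for _ in range(world_size)]
        dist.all_gather(gathered, local_metric)
        global_mean = torch.stack(gathered).mean()

        if rank == 0:
            print(f"local={local_metric.item():.1f} global={global_mean.item():.1f}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)

With two ranks the script prints local=1.0 global=1.5; that gap between the
rank-local mean and the all-rank mean is exactly the discrepancy the patch
removes from the logged DPO metrics.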