Skip to content

Commit

Permalink
dpo_trainer: gather metrics across ranks before logging
Browse files Browse the repository at this point in the history
according to issue huggingface#2468
  • Loading branch information
zhc7 authored Dec 13, 2024
1 parent ca850be commit 1e5df8c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if isinstance(metrics[0], torch.Tensor):
gathered = self._nested_gather([m.cuda() for m in metrics])
metrics = [g.mean() for g in gathered]
meaned = torch.tensor(metrics).mean()
logs[key] = meaned.item()
del self._stored_metrics[train_eval]

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
Expand Down

0 comments on commit 1e5df8c

Please sign in to comment.