[DPO] Resolve logging for DPOTrainer (#570)
* Resolve logging for DPOTrainer

* Ensure the WandB logger correctly prefixes all logs

* Run pre-commit

Whoops, hadn't run `pre-commit install` yet
tomaarsen authored Jul 26, 2023
1 parent d78d917 commit b3c2e73
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions trl/trainer/dpo_trainer.py
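In short, the diff below renames the metric keys that get_batch_metrics emits and routes them through a new store-then-average path so that the Hugging Face logging callbacks (including the W&B logger mentioned in the commit message) receive them with the expected prefixes. Roughly, reading the removed and added lines side by side:

    rewards_train/chosen    ->  rewards/chosen
    rewards_test/chosen     ->  eval_rewards/chosen
    logps_train/rejected    ->  logps/rejected
    logits_test/chosen      ->  eval_logits/chosen
    loss/train, loss/test   ->  dropped (the Trainer already logs loss and eval_loss itself)

Per-step values are no longer logged directly; they are accumulated via store_metrics and flushed as means when the Trainer calls log().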
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -155,6 +156,8 @@ def __init__(
         self.beta = beta
         self.ref_model = ref_model
 
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
         super().__init__(
             model,
             args,
@@ -304,7 +307,7 @@ def get_batch_metrics(
         self,
         model,
         batch: Dict[str, Union[List, torch.LongTensor]],
-        train_test: str = "train",
+        train_eval: Literal["train", "eval"] = "train",
     ):
         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
@@ -331,17 +334,15 @@ def get_batch_metrics(
         )
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
-        metrics[f"rewards_{train_test}/chosen"] = chosen_rewards.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/rejected"] = rejected_rewards.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/accuracies"] = reward_accuracies.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
-        metrics[f"logps_{train_test}/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
-        metrics[f"logps_{train_test}/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
-
-        metrics[f"logits_{train_test}/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
-        metrics[f"logits_{train_test}/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
-
-        metrics[f"loss/{train_test}"] = losses.detach().cpu().numpy().mean()
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
 
         return losses.mean(), metrics

@@ -356,11 +357,11 @@ def compute_loss(
                 "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                 "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
             )
-        loss, metrics = self.get_batch_metrics(model, inputs, train_test="train")
+        loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
 
         # force log the metrics
         if self.accelerator.is_main_process:
-            self.log_metrics("train", metrics)
+            self.store_metrics(metrics, train_eval="train")
 
         if return_outputs:
             return (loss, metrics)
@@ -412,11 +413,11 @@ def prediction_step(
                 ignore_keys = []
 
         with torch.no_grad():
-            loss, metrics = self.get_batch_metrics(model, inputs, train_test="test")
+            loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
 
         # force log the metrics
         if self.accelerator.is_main_process:
-            self.log_metrics("test", metrics)
+            self.store_metrics(metrics, train_eval="eval")
 
         if prediction_loss_only:
             return (loss.detach(), None, None)
@@ -431,3 +432,23 @@ def prediction_step(
         labels = torch.zeros(logits.shape[0])
 
         return (loss.detach(), logits, labels)
+
+    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs)

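For context, a rough usage sketch of how the reworked logging surfaces during a run. The DPOTrainer constructor arguments (model, ref_model, beta, args, datasets, tokenizer), the prompt/chosen/rejected column names, and the TrainingArguments flags are assumed from the TRL documentation of this period rather than taken from the diff itself:

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import DPOTrainer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Toy preference pairs; a real run needs far more data.
    train_dataset = Dataset.from_dict(
        {
            "prompt": ["The capital of France is"],
            "chosen": [" Paris."],
            "rejected": [" Berlin."],
        }
    )

    args = TrainingArguments(
        output_dir="dpo-logging-test",
        per_device_train_batch_size=1,
        logging_steps=1,
        evaluation_strategy="steps",
        eval_steps=1,
        report_to="wandb",            # exercises the W&B prefixing addressed by this commit
        remove_unused_columns=False,  # keep the raw text columns for the DPO data collator
    )

    trainer = DPOTrainer(
        model,
        ref_model,
        beta=0.1,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    # Training logs now carry averaged rewards/chosen, rewards/margins, logps/chosen, ...
    # Evaluation logs carry eval_rewards/chosen, eval_logps/chosen, ..., which the W&B
    # integration should group under its eval section, per the commit message's intent.

This stays within the behavior shown in the diff: store_metrics accumulates per-step values, and the overridden log() averages and emits them at each logging step.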