Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DPO] Resolve logging for DPOTrainer #570

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -155,6 +156,8 @@ def __init__(
self.beta = beta
self.ref_model = ref_model

self._stored_metrics = defaultdict(lambda: defaultdict(list))

super().__init__(
model,
args,
Expand Down Expand Up @@ -304,7 +307,7 @@ def get_batch_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_test: str = "train",
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
Expand All @@ -331,17 +334,15 @@ def get_batch_metrics(
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

metrics[f"rewards_{train_test}/chosen"] = chosen_rewards.cpu().numpy().mean()
metrics[f"rewards_{train_test}/rejected"] = rejected_rewards.cpu().numpy().mean()
metrics[f"rewards_{train_test}/accuracies"] = reward_accuracies.cpu().numpy().mean()
metrics[f"rewards_{train_test}/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
metrics[f"logps_{train_test}/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
metrics[f"logps_{train_test}/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()

metrics[f"logits_{train_test}/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
metrics[f"logits_{train_test}/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()

metrics[f"loss/{train_test}"] = losses.detach().cpu().numpy().mean()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()

return losses.mean(), metrics

Expand All @@ -356,11 +357,11 @@ def compute_loss(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
loss, metrics = self.get_batch_metrics(model, inputs, train_test="train")
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")

# force log the metrics
if self.accelerator.is_main_process:
self.log_metrics("train", metrics)
self.store_metrics(metrics, train_eval="train")

if return_outputs:
return (loss, metrics)
Expand Down Expand Up @@ -412,11 +413,11 @@ def prediction_step(
ignore_keys = []

with torch.no_grad():
loss, metrics = self.get_batch_metrics(model, inputs, train_test="test")
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")

# force log the metrics
if self.accelerator.is_main_process:
self.log_metrics("test", metrics)
self.store_metrics(metrics, train_eval="eval")

if prediction_loss_only:
return (loss.detach(), None, None)
Expand All @@ -431,3 +432,23 @@ def prediction_step(
labels = torch.zeros(logits.shape[0])

return (loss.detach(), logits, labels)

def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)

def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)