From 4e847d6275071e9c512c64787f7a2b15154ac306 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Tue, 25 Jul 2023 13:49:12 +0200
Subject: [PATCH 1/3] Resolve logging for DPOTrainer

---
 trl/trainer/dpo_trainer.py | 53 ++++++++++++++++++++++++++------------
 1 file changed, 37 insertions(+), 16 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 67f2bf9895..b93116cba8 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import defaultdict
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -155,6 +156,8 @@ def __init__(
         self.beta = beta
         self.ref_model = ref_model
 
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
         super().__init__(
             model,
             args,
@@ -304,7 +307,7 @@ def get_batch_metrics(
         self,
         model,
         batch: Dict[str, Union[List, torch.LongTensor]],
-        train_test: str = "train",
+        train_eval: Literal["train", "eval"] = "train",
     ):
         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
@@ -331,17 +334,15 @@ def get_batch_metrics(
         )
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
-        metrics[f"rewards_{train_test}/chosen"] = chosen_rewards.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/rejected"] = rejected_rewards.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/accuracies"] = reward_accuracies.cpu().numpy().mean()
-        metrics[f"rewards_{train_test}/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
-        metrics[f"logps_{train_test}/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
-        metrics[f"logps_{train_test}/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
-
-        metrics[f"logits_{train_test}/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
-        metrics[f"logits_{train_test}/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
+        metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
+        metrics[f"{train_eval}_rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
+        metrics[f"{train_eval}_rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
+        metrics[f"{train_eval}_rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
+        metrics[f"{train_eval}_logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
+        metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
 
-        metrics[f"loss/{train_test}"] = losses.detach().cpu().numpy().mean()
+        metrics[f"{train_eval}_logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
+        metrics[f"{train_eval}_logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
 
         return losses.mean(), metrics
 
@@ -356,11 +357,11 @@ def compute_loss(
                 "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                 "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
             )
-        loss, metrics = self.get_batch_metrics(model, inputs, train_test="train")
+        loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
 
         # force log the metrics
         if self.accelerator.is_main_process:
-            self.log_metrics("train", metrics)
+            self.store_metrics(metrics, train_eval="train")
 
         if return_outputs:
             return (loss, metrics)
@@ -412,11 +413,11 @@ def prediction_step(
                 ignore_keys = []
 
         with torch.no_grad():
-            loss, metrics = self.get_batch_metrics(model, inputs, train_test="test")
+            loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
 
         # force log the metrics
         if self.accelerator.is_main_process:
-            self.log_metrics("test", metrics)
+            self.store_metrics(metrics, train_eval="eval")
 
         if prediction_loss_only:
             return (loss.detach(), None, None)
@@ -431,3 +432,23 @@ def prediction_step(
         labels = torch.zeros(logits.shape[0])
 
         return (loss.detach(), logits, labels)
+
+    def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs)

From 344e5d36520c45587b80adca7d042650afe8c949 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Tue, 25 Jul 2023 14:22:51 +0200
Subject: [PATCH 2/3] Ensure the WandB logger correctly prefixes all logs

---
 trl/trainer/dpo_trainer.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index b93116cba8..c85670cb76 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -334,15 +334,15 @@ def get_batch_metrics(
         )
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
-        metrics[f"{train_eval}_rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
-        metrics[f"{train_eval}_rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
-        metrics[f"{train_eval}_rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
-        metrics[f"{train_eval}_rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
-        metrics[f"{train_eval}_logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
-        metrics[f"{train_eval}_logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
-
-        metrics[f"{train_eval}_logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
-        metrics[f"{train_eval}_logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
 
         return losses.mean(), metrics
 

From 4aa3fb033f12ac7616ec513649e25f4fa3b37e70 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Tue, 25 Jul 2023 14:31:03 +0200
Subject: [PATCH 3/3] Run pre-commit

Whoops, hadn't run `pre-commit install` yet
---
 trl/trainer/dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index c85670cb76..af24451b53 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections import defaultdict
 import warnings
+from collections import defaultdict
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import torch