diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py
index f6d2d21e4..556eaefa3 100644
--- a/recipes/full_dpo_distributed.py
+++ b/recipes/full_dpo_distributed.py
@@ -865,14 +865,25 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        _, rank = get_world_size_and_rank()
+        world_size, rank = get_world_size_and_rank()
 
         # zero out the gradients before starting training
         self._optimizer.zero_grad()
 
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
+
+        # Running metrics
         running_loss = 0
+        running_metrics = {
+            "rewards/chosen": 0,
+            "rewards/rejected": 0,
+            "rewards/accuracies": 0,
+            "log_probs/chosen": 0,
+            "log_probs/rejected": 0,
+            "logits/chosen": 0,
+            "logits/rejected": 0,
+        }
         num_tokens = 0
 
         self._profiler.start()
@@ -922,16 +933,34 @@ def train(self) -> None:
                     reference_chosen_log_probs,
                     reference_rejected_log_probs,
                 )
 
+                reward_accuracies = (chosen_rewards > rejected_rewards).float()
                 loss = loss.mean()
-                reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
                 loss = loss / self._gradient_accumulation_steps
+
+                # Update running metrics
                 running_loss += loss
+                scaling_factor = (1 / self._gradient_accumulation_steps)  # to average out between grad_acc steps
+                running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
+                running_metrics["rewards/rejected"] += scaling_factor * rejected_rewards.mean()
+                running_metrics["rewards/accuracies"] += scaling_factor * reward_accuracies.mean()
+                running_metrics["log_probs/chosen"] += scaling_factor * policy_chosen_log_probs.detach().mean()
+                running_metrics["log_probs/rejected"] += scaling_factor * policy_rejected_log_probs.detach().mean()
+                running_metrics["logits/chosen"] += scaling_factor * policy_chosen_logits_mean
+                running_metrics["logits/rejected"] += scaling_factor * policy_rejected_logits_mean
+
                 loss.backward()
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    # Accumulate running metrics across all devices
+                    torch.distributed.all_reduce(running_loss)
+                    torch.distributed.all_reduce(num_tokens)
+
+                    for key in running_metrics:
+                        torch.distributed.all_reduce(running_metrics[key], op=torch.distributed.ReduceOp.AVG)
+
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
                     self._lr_scheduler.step()
@@ -954,21 +983,15 @@ def train(self) -> None:
                         log_dict = {
                             "loss": loss_to_log,
                             "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
-                            "rewards/chosen": chosen_rewards.mean().cpu(),
-                            "rewards/rejected": rejected_rewards.mean().cpu(),
-                            "rewards/accuracies": reward_accuracies.mean().cpu(),
-                            "rewards/margins": (chosen_rewards - rejected_rewards)
-                            .mean()
-                            .cpu(),
-                            "log_probs/rejected": policy_rejected_log_probs.detach()
-                            .mean()
-                            .cpu(),
-                            "log_probs/chosen": policy_chosen_log_probs.detach()
-                            .mean()
-                            .cpu(),
-                            "logits/rejected": policy_rejected_logits_mean.cpu(),
-                            "logits/chosen": policy_chosen_logits_mean.cpu(),
+                            "tokens_per_second_per_gpu": num_tokens / (time_per_step * world_size),
+                            "rewards/chosen": running_metrics["rewards/chosen"].cpu(),
+                            "rewards/rejected": running_metrics["rewards/rejected"].cpu(),
+                            "rewards/accuracies": running_metrics["rewards/accuracies"].cpu(),
+                            "rewards/margins": (running_metrics["rewards/chosen"] - running_metrics["rewards/rejected"]).cpu(),
+                            "log_probs/chosen": running_metrics["log_probs/chosen"].cpu(),
+                            "log_probs/rejected": running_metrics["log_probs/rejected"].cpu(),
+                            "logits/chosen": running_metrics["logits/chosen"].cpu(),
+                            "logits/rejected": running_metrics["logits/rejected"].cpu(),
                         }
                         if self._log_peak_memory_stats:
                             log_dict.update(
@@ -981,7 +1004,9 @@ def train(self) -> None:
 
                     # Reset running stats for the next step
                     running_loss = 0
+                    running_metrics = {key: 0 for key in running_metrics}
                     num_tokens = 0
+                    t0 = time.perf_counter()
 
                 # Step profiler
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index d54adc2cf..16d25cef7 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -651,14 +651,25 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        _, rank = utils.get_world_size_and_rank()
+        world_size, rank = utils.get_world_size_and_rank()
 
        # zero out the gradients before starting training
         self._optimizer.zero_grad()
 
         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
+
+        # Running metrics
         running_loss = 0
+        running_metrics = {
+            "rewards/chosen": 0,
+            "rewards/rejected": 0,
+            "rewards/accuracies": 0,
+            "log_probs/chosen": 0,
+            "log_probs/rejected": 0,
+            "logits/chosen": 0,
+            "logits/rejected": 0,
+        }
         num_tokens = 0
 
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
@@ -706,16 +717,34 @@ def train(self) -> None:
                     reference_chosen_log_probs,
                     reference_rejected_log_probs,
                 )
 
+                reward_accuracies = (chosen_rewards > rejected_rewards).float()
                 loss = loss.mean()
-                reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
                 loss = loss / self._gradient_accumulation_steps
+
+                # Update running metrics
                 running_loss += loss
+                scaling_factor = (1 / self._gradient_accumulation_steps)  # to average out between grad_acc steps
+                running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
+                running_metrics["rewards/rejected"] += scaling_factor * rejected_rewards.mean()
+                running_metrics["rewards/accuracies"] += scaling_factor * reward_accuracies.mean()
+                running_metrics["log_probs/chosen"] += scaling_factor * policy_chosen_log_probs.detach().mean()
+                running_metrics["log_probs/rejected"] += scaling_factor * policy_rejected_log_probs.detach().mean()
+                running_metrics["logits/chosen"] += scaling_factor * policy_chosen_logits_mean
+                running_metrics["logits/rejected"] += scaling_factor * policy_rejected_logits_mean
+
                 loss.backward()
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    # Accumulate running metrics across all devices
+                    torch.distributed.all_reduce(running_loss)
+                    torch.distributed.all_reduce(num_tokens)
+
+                    for key in running_metrics:
+                        torch.distributed.all_reduce(running_metrics[key], op=torch.distributed.ReduceOp.AVG)
+
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
                     self._lr_scheduler.step()
@@ -738,21 +767,15 @@ def train(self) -> None:
                         log_dict = {
                             "loss": loss_to_log,
                             "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
-                            "rewards/chosen": chosen_rewards.mean().cpu(),
-                            "rewards/rejected": rejected_rewards.mean().cpu(),
-                            "rewards/accuracies": reward_accuracies.mean().cpu(),
-                            "rewards/margins": (chosen_rewards - rejected_rewards)
-                            .mean()
-                            .cpu(),
-                            "log_probs/rejected": policy_rejected_log_probs.detach()
-                            .mean()
-                            .cpu(),
-                            "log_probs/chosen": policy_chosen_log_probs.detach()
-                            .mean()
-                            .cpu(),
-                            "logits/rejected": policy_rejected_logits_mean.cpu(),
-                            "logits/chosen": policy_chosen_logits_mean.cpu(),
+                            "tokens_per_second_per_gpu": num_tokens / (time_per_step * world_size),
+                            "rewards/chosen": running_metrics["rewards/chosen"].cpu(),
+                            "rewards/rejected": running_metrics["rewards/rejected"].cpu(),
+                            "rewards/accuracies": running_metrics["rewards/accuracies"].cpu(),
+                            "rewards/margins": (running_metrics["rewards/chosen"] - running_metrics["rewards/rejected"]).cpu(),
+                            "log_probs/chosen": running_metrics["log_probs/chosen"].cpu(),
+                            "log_probs/rejected": running_metrics["log_probs/rejected"].cpu(),
+                            "logits/chosen": running_metrics["logits/chosen"].cpu(),
+                            "logits/rejected": running_metrics["logits/rejected"].cpu(),
                         }
                         if self._log_peak_memory_stats:
                             log_dict.update(
@@ -765,7 +788,9 @@ def train(self) -> None:
 
                     # Reset running stats for the next step
                     running_loss = 0
+                    running_metrics = {key: 0 for key in running_metrics}
                     num_tokens = 0
+                    t0 = time.perf_counter()
 
             self.epochs_run += 1
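
Note: the accumulate-then-all-reduce pattern added in both recipes can be sketched in isolation as below. This is a minimal sketch under stated assumptions, not code from the PR: `reduce_step_metrics`, `metrics_per_microbatch`, and `grad_accum_steps` are illustrative names, the process group is assumed to be already initialized, and `ReduceOp.AVG` requires the NCCL backend.

```python
# Minimal sketch of the metric-accumulation pattern used in the diff above.
# Assumes torch.distributed is already initialized (NCCL backend, since
# ReduceOp.AVG is NCCL-only) and that `metrics_per_microbatch` is a list of
# dicts of scalar tensors, one dict per gradient-accumulation microbatch.
import torch
import torch.distributed as dist


def reduce_step_metrics(metrics_per_microbatch, grad_accum_steps):
    # Average each metric over the accumulation window on this rank.
    running = {}
    for micro in metrics_per_microbatch:
        for key, value in micro.items():
            running[key] = running.get(key, 0) + value.detach() / grad_accum_steps

    # Average across data-parallel ranks so every rank holds the same value.
    if dist.is_available() and dist.is_initialized():
        for key in running:
            dist.all_reduce(running[key], op=dist.ReduceOp.AVG)

    # Move to CPU only for logging, after the collectives are done.
    return {key: value.cpu() for key, value in running.items()}
```

Every rank has to enter the `all_reduce` calls, which is why the reduction in the recipes happens before the `rank == 0` logging check; averaging (rather than the default sum) keeps the logged values comparable to a single-GPU run.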