Merge pull request #2 from bogdansalyp/fix/running_metrics_and_sync_lora_dpo

fix: Running metrics and tokens_per_second_per_gpu fixes for DPO recipes
sam-pi authored Feb 1, 2025
2 parents ebf288a + 1a673df commit 16821c4
Showing 2 changed files with 84 additions and 34 deletions.
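Both recipes get the same treatment: each micro-batch's DPO statistics (rewards, reward accuracies, log-probs, logits) are scaled by 1/gradient_accumulation_steps as they are folded into a running_metrics dict, then averaged across data-parallel ranks with torch.distributed.all_reduce(..., op=ReduceOp.AVG) at the optimizer step. The snippet below is a minimal sketch of that accumulation pattern in isolation; the values and the grad_accum_steps name are placeholders for illustration, not taken from the recipes.

```python
import torch
import torch.distributed as dist

# Minimal sketch of the running-metrics pattern this PR introduces.
# Values are placeholders; the recipes pull them from the DPO loss outputs.
grad_accum_steps = 4
running_metrics = {"rewards/chosen": 0, "rewards/accuracies": 0}

for step in range(grad_accum_steps):
    # stand-ins for chosen_rewards.mean() and reward_accuracies.mean()
    chosen_mean = torch.tensor(0.40 + 0.05 * step)
    accuracy_mean = torch.tensor(0.75)

    # scale so the accumulated value ends up as the mean over the window
    scaling_factor = 1 / grad_accum_steps
    running_metrics["rewards/chosen"] += scaling_factor * chosen_mean
    running_metrics["rewards/accuracies"] += scaling_factor * accuracy_mean

# At the optimizer step, average each metric across data-parallel ranks.
# Guarded so the sketch also runs in a single, non-distributed process.
if dist.is_available() and dist.is_initialized():
    for key in running_metrics:
        dist.all_reduce(running_metrics[key], op=dist.ReduceOp.AVG)

# ...log the synced values, then reset for the next optimizer step
running_metrics = {key: 0 for key in running_metrics}
```

Note that the running_loss and num_tokens all-reduces in the diff keep the default ReduceOp.SUM, which is the right choice for totals, while the per-batch metrics use ReduceOp.AVG so the logged value stays on the scale of a single rank's mean.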
59 changes: 42 additions & 17 deletions recipes/full_dpo_distributed.py
@@ -865,14 +865,25 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = get_world_size_and_rank()
world_size, rank = get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()

# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()

# Running metrics
running_loss = 0
running_metrics = {
"rewards/chosen": 0,
"rewards/rejected": 0,
"rewards/accuracies": 0,
"log_probs/chosen": 0,
"log_probs/rejected": 0,
"logits/chosen": 0,
"logits/rejected": 0,
}
num_tokens = 0

self._profiler.start()
@@ -922,16 +933,34 @@ def train(self) -> None:
reference_chosen_log_probs,
reference_rejected_log_probs,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss / self._gradient_accumulation_steps

# Update running metrics
running_loss += loss
scaling_factor = (1 / self._gradient_accumulation_steps) # to average out between grad_acc steps
running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
running_metrics["rewards/rejected"] += scaling_factor * rejected_rewards.mean()
running_metrics["rewards/accuracies"] += scaling_factor * reward_accuracies.mean()
running_metrics["log_probs/chosen"] += scaling_factor * policy_chosen_log_probs.detach().mean()
running_metrics["log_probs/rejected"] += scaling_factor * policy_rejected_log_probs.detach().mean()
running_metrics["logits/chosen"] += scaling_factor * policy_chosen_logits_mean
running_metrics["logits/rejected"] += scaling_factor * policy_rejected_logits_mean

loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
# Accumulate running metrics across all devices
torch.distributed.all_reduce(running_loss)
torch.distributed.all_reduce(num_tokens)

for key in running_metrics:
torch.distributed.all_reduce(running_metrics[key], op=torch.distributed.ReduceOp.AVG)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
@@ -954,21 +983,15 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"rewards/chosen": chosen_rewards.mean().cpu(),
"rewards/rejected": rejected_rewards.mean().cpu(),
"rewards/accuracies": reward_accuracies.mean().cpu(),
"rewards/margins": (chosen_rewards - rejected_rewards)
.mean()
.cpu(),
"log_probs/rejected": policy_rejected_log_probs.detach()
.mean()
.cpu(),
"log_probs/chosen": policy_chosen_log_probs.detach()
.mean()
.cpu(),
"logits/rejected": policy_rejected_logits_mean.cpu(),
"logits/chosen": policy_chosen_logits_mean.cpu(),
"tokens_per_second_per_gpu": num_tokens / (time_per_step * world_size),
"rewards/chosen": running_metrics["rewards/chosen"].cpu(),
"rewards/rejected": running_metrics["rewards/rejected"].cpu(),
"rewards/accuracies": running_metrics["rewards/accuracies"].cpu(),
"rewards/margins": (running_metrics["rewards/chosen"] - running_metrics["rewards/rejected"]).cpu(),
"log_probs/chosen": running_metrics["log_probs/chosen"].cpu(),
"log_probs/rejected": running_metrics["log_probs/rejected"].cpu(),
"logits/chosen": running_metrics["logits/chosen"].cpu(),
"logits/rejected": running_metrics["logits/rejected"].cpu(),
}
if self._log_peak_memory_stats:
log_dict.update(
@@ -981,7 +1004,9 @@ def train(self) -> None:

# Reset running stats for the next step
running_loss = 0
running_metrics = {key: 0 for key in running_metrics}
num_tokens = 0

t0 = time.perf_counter()

# Step profiler
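The other change in the hunks above is the throughput metric. Since num_tokens is synced with a SUM all-reduce, num_tokens / time_per_step is the aggregate rate of the whole job; dividing by world_size as well makes the logged number match the tokens_per_second_per_gpu name. A quick check with made-up numbers, not from a real run:

```python
# Hypothetical numbers, for illustration only.
world_size = 8          # data-parallel GPUs
time_per_step = 2.0     # seconds between optimizer steps
num_tokens = 32_000     # global token count after the SUM all-reduce

aggregate_rate = num_tokens / time_per_step               # 16,000 tokens/s across the whole job
per_gpu_rate = num_tokens / (time_per_step * world_size)  # 2,000 tokens/s per GPU
```

The second file below applies the identical change to the LoRA DPO recipe.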
59 changes: 42 additions & 17 deletions recipes/lora_dpo_distributed.py
@@ -651,14 +651,25 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = utils.get_world_size_and_rank()
world_size, rank = utils.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()

# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()

# Running metrics
running_loss = 0
running_metrics = {
"rewards/chosen": 0,
"rewards/rejected": 0,
"rewards/accuracies": 0,
"log_probs/chosen": 0,
"log_probs/rejected": 0,
"logits/chosen": 0,
"logits/rejected": 0,
}
num_tokens = 0

# self.epochs_run should be non-zero when we're resuming from a checkpoint
@@ -706,16 +717,34 @@ def train(self) -> None:
reference_chosen_log_probs,
reference_rejected_log_probs,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()

loss = loss / self._gradient_accumulation_steps

# Update running metrics
running_loss += loss
scaling_factor = (1 / self._gradient_accumulation_steps) # to average out between grad_acc steps
running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
running_metrics["rewards/rejected"] += scaling_factor * rejected_rewards.mean()
running_metrics["rewards/accuracies"] += scaling_factor * reward_accuracies.mean()
running_metrics["log_probs/chosen"] += scaling_factor * policy_chosen_log_probs.detach().mean()
running_metrics["log_probs/rejected"] += scaling_factor * policy_rejected_log_probs.detach().mean()
running_metrics["logits/chosen"] += scaling_factor * policy_chosen_logits_mean
running_metrics["logits/rejected"] += scaling_factor * policy_rejected_logits_mean

loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
# Accumulate running metrics across all devices
torch.distributed.all_reduce(running_loss)
torch.distributed.all_reduce(num_tokens)

for key in running_metrics:
torch.distributed.all_reduce(running_metrics[key], op=torch.distributed.ReduceOp.AVG)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
@@ -738,21 +767,15 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"rewards/chosen": chosen_rewards.mean().cpu(),
"rewards/rejected": rejected_rewards.mean().cpu(),
"rewards/accuracies": reward_accuracies.mean().cpu(),
"rewards/margins": (chosen_rewards - rejected_rewards)
.mean()
.cpu(),
"log_probs/rejected": policy_rejected_log_probs.detach()
.mean()
.cpu(),
"log_probs/chosen": policy_chosen_log_probs.detach()
.mean()
.cpu(),
"logits/rejected": policy_rejected_logits_mean.cpu(),
"logits/chosen": policy_chosen_logits_mean.cpu(),
"tokens_per_second_per_gpu": num_tokens / (time_per_step * world_size),
"rewards/chosen": running_metrics["rewards/chosen"].cpu(),
"rewards/rejected": running_metrics["rewards/rejected"].cpu(),
"rewards/accuracies": running_metrics["rewards/accuracies"].cpu(),
"rewards/margins": (running_metrics["rewards/chosen"] - running_metrics["rewards/rejected"]).cpu(),
"log_probs/chosen": running_metrics["log_probs/chosen"].cpu(),
"log_probs/rejected": running_metrics["log_probs/rejected"].cpu(),
"logits/chosen": running_metrics["logits/chosen"].cpu(),
"logits/rejected": running_metrics["logits/rejected"].cpu(),
}
if self._log_peak_memory_stats:
log_dict.update(
@@ -765,7 +788,9 @@ def train(self) -> None:

# Reset running stats for the next step
running_loss = 0
running_metrics = {key: 0 for key in running_metrics}
num_tokens = 0

t0 = time.perf_counter()

self.epochs_run += 1

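A practical way to read the change, using made-up values rather than numbers from a real run: previously the reward and log-prob entries in log_dict came from whichever micro-batch happened to be the last one in the accumulation window, on the local rank only; now they are the average over the whole window, synced across ranks.

```python
# Illustrative rewards/accuracies values for one gradient-accumulation
# window of 4 micro-batches on a single rank (not from a real run).
per_microbatch = [0.25, 0.50, 0.75, 1.00]

old_logged = per_microbatch[-1]                         # 1.00: last micro-batch only
new_logged = sum(per_microbatch) / len(per_microbatch)  # 0.625: window average
```

rewards/margins is likewise derived from the window-averaged chosen and rejected rewards instead of the final micro-batch, and the ReduceOp.AVG all-reduce means the logged numbers reflect every rank rather than just the one doing the logging.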