RDPO fix nll loss (#1705)
kashif authored Jun 7, 2024
1 parent b8b972f commit 5bcb8ad
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions trl/trainer/dpo_trainer.py
@@ -1122,7 +1122,7 @@ def get_batch_logps(
 
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
 
         We do this to avoid doing two forward passes, because it's faster for FSDP.
@@ -1158,7 +1158,23 @@ def concatenated_forward(
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
-        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]
+
+        def cross_entropy_loss(logits, labels):
+            if not self.is_encoder_decoder:
+                # Shift so that tokens < n predict n
+                logits = logits[..., :-1, :].contiguous()
+                labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = concatenated_batch["concatenated_labels"].clone()
+        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion
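The nll_loss computed above is an ordinary next-token cross entropy over the chosen half of the concatenated batch (padded label positions carry label_pad_token_id, which defaults to -100 and is ignored by nn.CrossEntropyLoss). Below is a minimal standalone sketch of the same shift-and-flatten computation on dummy tensors; shapes and names are illustrative, not the trainer's own code.

import torch
import torch.nn as nn

batch_size, seq_len, vocab_size = 2, 6, 11
logits = torch.randn(batch_size, seq_len, vocab_size)         # policy logits for the chosen sequences
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # token ids for the chosen sequences

# Shift so that tokens < n predict n (decoder-only case)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens and average the cross entropy over non-ignored positions
loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
nll_loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(nll_loss)  # scalar tensor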
@@ -1169,7 +1185,7 @@ def concatenated_forward(
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]
 
-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
 
     def get_batch_loss_metrics(
         self,
@@ -1185,7 +1201,7 @@ def get_batch_loss_metrics(
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
-            policy_chosen_logps_avg,
+            policy_nll_loss,
         ) = self.concatenated_forward(model, batch)
 
         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
@@ -1225,7 +1241,7 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()
 
         if self.args.rpo_alpha is not None:
-            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg
+            losses = losses * self.args.rpo_alpha + policy_nll_loss
 
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
@@ -1236,6 +1252,8 @@ def get_batch_loss_metrics(
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        if self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
         return losses.mean(), metrics
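For context, a hypothetical end-to-end usage sketch of the term this commit fixes: setting rpo_alpha enables the NLL addition. This assumes a TRL version where DPOConfig exposes rpo_alpha; the model name and data file are placeholders, and the dataset is assumed to have prompt/chosen/rejected columns.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
train_dataset = load_dataset("json", data_files="preferences.json", split="train")  # placeholder data

config = DPOConfig(
    output_dir="dpo-rpo-out",
    rpo_alpha=1.0,  # not None -> policy_nll_loss is added to the DPO loss, as patched above
    per_device_train_batch_size=2,
    max_steps=10,
)
trainer = DPOTrainer(model=model, args=config, tokenizer=tokenizer, train_dataset=train_dataset)
trainer.train()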
