
Commit 3d745a2

krammnic and Mark Obozov authored
Custom DPO losses support (#2427)
Co-authored-by: Mark Obozov <markobozov@MacBook-Pro-Mark.local>
1 parent 8bf8647 commit 3d745a2

File tree

8 files changed: +160 −107 lines changed

docs/source/recipes/dpo.rst

+22
@@ -59,6 +59,28 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
     loss=torchtune.modules.loss.RSOLoss \
     gamma=0.5

+Also, you can pass your own custom loss to our recipe. Note that its `forward` method should align with the following signature:
+
+.. code-block:: python
+
+    def forward(self, policy_inputs: ChosenRejectedOutputs, reference_inputs: ChosenRejectedOutputs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        ...
+
+Here, `ChosenRejectedOutputs` is a dataclass obtained from `concatenated_forward`:
+
+.. code-block:: python
+
+    @dataclass
+    class ChosenRejectedOutputs:
+        chosen_logps: torch.Tensor
+        rejected_logps: torch.Tensor
+        chosen_logits: torch.Tensor
+        rejected_logits: torch.Tensor
+
+If this is not sufficient and you need to compute additional values from the logits, you can modify `concatenated_forward` directly. To do this, use `tune cp` to copy the desired recipe, and don't forget to use your own dataclass!
+
+Refer to the TRL library for reference implementations of the desired losses. In particular, you may find useful loss calculations in its trainers.
+
 For a deeper understanding of the different levers you can pull when using this recipe,
 see our documentation for the different PEFT training paradigms we support:
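To make the documented contract concrete, here is a minimal sketch of a custom preference loss that satisfies this signature. The class name SimpleSigmoidDPOLoss, the beta hyperparameter, and the sigmoid objective are illustrative assumptions; only the forward argument and return types are fixed by the recipes.

# Illustrative custom loss; everything except the forward() contract
# (two ChosenRejectedOutputs in, three tensors out) is an assumption.
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtune.rlhf import ChosenRejectedOutputs


class SimpleSigmoidDPOLoss(nn.Module):
    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_inputs: ChosenRejectedOutputs,
        reference_inputs: ChosenRejectedOutputs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Log-ratio of chosen vs. rejected completions under policy and reference
        policy_logratios = policy_inputs.chosen_logps - policy_inputs.rejected_logps
        reference_logratios = (
            reference_inputs.chosen_logps - reference_inputs.rejected_logps
        )
        logits = policy_logratios - reference_logratios

        # Unreduced per-example losses; the single-device recipe below calls loss.mean()
        losses = -F.logsigmoid(self.beta * logits)

        # Implicit rewards, detached because the recipes only log them as metrics
        chosen_rewards = self.beta * (
            policy_inputs.chosen_logps - reference_inputs.chosen_logps
        ).detach()
        rejected_rewards = self.beta * (
            policy_inputs.rejected_logps - reference_inputs.rejected_logps
        ).detach()
        return losses, chosen_rewards, rejected_rewards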

recipes/full_dpo_distributed.py

+29 −25

@@ -20,6 +20,7 @@
 from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.rlhf import ChosenRejectedOutputs
 from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY
 from torchtune.training.lr_schedulers import get_lr
 from torchtune.utils import get_world_size_and_rank
@@ -797,7 +798,7 @@ def concatenated_forward(
         model: nn.Module,
         batch: Tuple[torch.Tensor, torch.Tensor],
         activations_handling: Optional[bool] = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> ChosenRejectedOutputs:
         """
         Run forward pass of the model with chosen and rejected samples concatenated.

@@ -806,7 +807,7 @@ def concatenated_forward(
             batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.

         Returns:
-            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+            Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
         """
         concatenated_input_ids, concatenated_labels = batch
         concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -836,7 +837,9 @@
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]

-        return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
+        return ChosenRejectedOutputs(
+            chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
+        )

     def train(self) -> None:
         """
@@ -884,36 +887,35 @@ def train(self) -> None:

                 # batch is input_ids, labels
                 num_tokens += torch.tensor(batch[0].numel())
-                (
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    policy_chosen_logits,
-                    policy_rejected_logits,
-                ) = self.concatenated_forward(self._model, batch)
+                policy_chosen_rejected_outputs = self.concatenated_forward(
+                    self._model, batch
+                )

-                policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
-                policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
+                policy_chosen_logits_mean = (
+                    policy_chosen_rejected_outputs.chosen_logits.detach().mean()
+                )
+                policy_rejected_logits_mean = (
+                    policy_chosen_rejected_outputs.rejected_logits.detach().mean()
+                )

                 # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
-                del policy_chosen_logits, policy_rejected_logits
+                del (
+                    policy_chosen_rejected_outputs.chosen_logits,
+                    policy_chosen_rejected_outputs.rejected_logits,
+                )

                 with torch.no_grad():
-                    (
-                        reference_chosen_log_probs,
-                        reference_rejected_log_probs,
-                        reference_chosen_logits,
-                        reference_rejected_logits,
-                    ) = self.concatenated_forward(
+                    reference_chosen_rejected_outputs = self.concatenated_forward(
                         self._ref_model, batch, activations_handling=False
                     )

-                    del reference_chosen_logits, reference_rejected_logits
+                    del (
+                        reference_chosen_rejected_outputs.chosen_logits,
+                        reference_chosen_rejected_outputs.rejected_logits,
+                    )

                 loss, chosen_rewards, rejected_rewards = self._loss_fn(
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    reference_chosen_log_probs,
-                    reference_rejected_log_probs,
+                    policy_chosen_rejected_outputs, reference_chosen_rejected_outputs
                 )
                 reward_accuracies = (chosen_rewards > rejected_rewards).float()

@@ -936,10 +938,12 @@ def train(self) -> None:
                     scaling_factor * reward_accuracies.mean()
                 )
                 running_metrics["log_probs/chosen"] += (
-                    scaling_factor * policy_chosen_log_probs.detach().mean()
+                    scaling_factor
+                    * policy_chosen_rejected_outputs.chosen_logps.detach().mean()
                 )
                 running_metrics["log_probs/rejected"] += (
-                    scaling_factor * policy_rejected_log_probs.detach().mean()
+                    scaling_factor
+                    * policy_chosen_rejected_outputs.rejected_logps.detach().mean()
                 )
                 running_metrics["logits/chosen"] += (
                     scaling_factor * policy_chosen_logits_mean
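As the updated docs above suggest, if a custom loss needs more than these four tensors, you can `tune cp` a recipe, extend the dataclass, and return the extended version from your copied `concatenated_forward`. A minimal sketch of that extension point follows; the ChosenRejectedOutputsWithEntropy name and the entropy computation are assumptions for illustration, not part of torchtune.

# Hypothetical extension for a copied recipe: one extra field computed from the logits.
from dataclasses import dataclass

import torch

from torchtune.rlhf import ChosenRejectedOutputs


@dataclass
class ChosenRejectedOutputsWithEntropy(ChosenRejectedOutputs):
    chosen_entropy: torch.Tensor


# Inside your copied concatenated_forward, after chosen_logits is computed:
#     probs = torch.softmax(chosen_logits, dim=-1)
#     chosen_entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=-1).mean()
#     return ChosenRejectedOutputsWithEntropy(
#         chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits, chosen_entropy
#     )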

recipes/lora_dpo_distributed.py

+28 −24

@@ -33,6 +33,7 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.rlhf import ChosenRejectedOutputs
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -614,7 +615,7 @@ def save_checkpoint(

     def concatenated_forward(
         self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> ChosenRejectedOutputs:
         """
         Run forward pass of the model with chosen and rejected samples concatenated.

@@ -623,7 +624,7 @@ def concatenated_forward(
             batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.

         Returns:
-            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+            Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
         """
         concatenated_input_ids, concatenated_labels = batch
         concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -643,7 +644,9 @@
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]

-        return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
+        return ChosenRejectedOutputs(
+            chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
+        )

     def train(self) -> None:
         """
@@ -690,31 +693,30 @@ def train(self) -> None:
                 # batch is input_ids, labels
                 num_tokens += torch.tensor(batch[0].numel())

-                (
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    policy_chosen_logits,
-                    policy_rejected_logits,
-                ) = self.concatenated_forward(self._model, batch)
+                policy_chosen_rejected_outputs = self.concatenated_forward(
+                    self._model, batch
+                )

-                policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
-                policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
+                policy_chosen_logits_mean = (
+                    policy_chosen_rejected_outputs.chosen_logits.detach().mean()
+                )
+                policy_rejected_logits_mean = (
+                    policy_chosen_rejected_outputs.rejected_logits.detach().mean()
+                )

                 # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
-                del policy_chosen_logits, policy_rejected_logits
+                del (
+                    policy_chosen_rejected_outputs.chosen_logits,
+                    policy_chosen_rejected_outputs.rejected_logits,
+                )

                 with torch.no_grad(), disable_adapter(self._model):
-                    (
-                        reference_chosen_log_probs,
-                        reference_rejected_log_probs,
-                        _,
-                        _,
-                    ) = self.concatenated_forward(self._model, batch)
+                    reference_chosen_rejected_outputs = self.concatenated_forward(
+                        self._model, batch
+                    )
                 loss, chosen_rewards, rejected_rewards = self._loss_fn(
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    reference_chosen_log_probs,
-                    reference_rejected_log_probs,
+                    policy_chosen_rejected_outputs,
+                    reference_chosen_rejected_outputs,
                 )
                 reward_accuracies = (chosen_rewards > rejected_rewards).float()

@@ -737,10 +739,12 @@ def train(self) -> None:
                     scaling_factor * reward_accuracies.mean()
                 )
                 running_metrics["log_probs/chosen"] += (
-                    scaling_factor * policy_chosen_log_probs.detach().mean()
+                    scaling_factor
+                    * policy_chosen_rejected_outputs.chosen_logps.detach().mean()
                 )
                 running_metrics["log_probs/rejected"] += (
-                    scaling_factor * policy_rejected_log_probs.detach().mean()
+                    scaling_factor
+                    * policy_chosen_rejected_outputs.rejected_logps.detach().mean()
                 )
                 running_metrics["logits/chosen"] += (
                     scaling_factor * policy_chosen_logits_mean

recipes/lora_dpo_single_device.py

+26 −24

@@ -30,6 +30,7 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.rlhf import ChosenRejectedOutputs

 from tqdm import tqdm

@@ -472,7 +473,7 @@ def save_checkpoint(self, epoch: int) -> None:

     def concatenated_forward(
         self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> ChosenRejectedOutputs:
         """
         Run forward pass of the model with chosen and rejected samples concatenated.

@@ -481,7 +482,7 @@ def concatenated_forward(
             batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.

         Returns:
-            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+            Dataclass of chosen log probs, rejected log probs, chosen logits, rejected logits.
         """
         concatenated_input_ids, concatenated_labels = batch
         concatenated_input_ids = concatenated_input_ids.to(self._device)
@@ -501,7 +502,9 @@
         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]

-        return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
+        return ChosenRejectedOutputs(
+            chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits
+        )

     def train(self) -> None:
         """
@@ -533,31 +536,30 @@ def train(self) -> None:

                 # batch is input_ids, labels
                 num_tokens += batch[0].numel()
-                (
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    policy_chosen_logits,
-                    policy_rejected_logits,
-                ) = self.concatenated_forward(self._model, batch)
+                policy_chosen_rejected_outputs = self.concatenated_forward(
+                    self._model, batch
+                )

-                policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
-                policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
+                policy_chosen_logits_mean = (
+                    policy_chosen_rejected_outputs.chosen_logits.detach().mean()
+                )
+                policy_rejected_logits_mean = (
+                    policy_chosen_rejected_outputs.rejected_logits.detach().mean()
+                )

                 # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
-                del policy_chosen_logits, policy_rejected_logits
+                del (
+                    policy_chosen_rejected_outputs.chosen_logits,
+                    policy_chosen_rejected_outputs.rejected_logits,
+                )

                 with torch.no_grad(), disable_adapter(self._model):
-                    (
-                        reference_chosen_log_probs,
-                        reference_rejected_log_probs,
-                        _,
-                        _,
-                    ) = self.concatenated_forward(self._model, batch)
+                    reference_chosen_rejected_outputs = self.concatenated_forward(
+                        self._model, batch
+                    )
                 loss, chosen_rewards, rejected_rewards = self._loss_fn(
-                    policy_chosen_log_probs,
-                    policy_rejected_log_probs,
-                    reference_chosen_log_probs,
-                    reference_rejected_log_probs,
+                    policy_chosen_rejected_outputs,
+                    reference_chosen_rejected_outputs,
                 )

                 loss = loss.mean()
@@ -596,10 +598,10 @@ def train(self) -> None:
                         "rewards/margins": (chosen_rewards - rejected_rewards)
                         .mean()
                         .cpu(),
-                        "log_probs/rejected": policy_rejected_log_probs.detach()
+                        "log_probs/rejected": policy_chosen_rejected_outputs.rejected_logps.detach()
                         .mean()
                         .cpu(),
-                        "log_probs/chosen": policy_chosen_log_probs.detach()
+                        "log_probs/chosen": policy_chosen_rejected_outputs.chosen_logps.detach()
                         .mean()
                         .cpu(),
                         "logits/rejected": policy_rejected_logits_mean.cpu(),

tests/torchtune/rlhf/loss/test_dpo_loss.py

+7 −1

@@ -6,6 +6,7 @@

 import pytest
 import torch
+from torchtune.rlhf._types import ChosenRejectedOutputs
 from torchtune.rlhf.loss import DPOLoss, RSOLoss


@@ -39,11 +40,16 @@ def loss_inputs(self):
         ref_chosen_logprobs = torch.tensor([-0.5, -10.1, -0.1])
         ref_rejected_logprobs = torch.tensor([-0.1, -20.1, -0.1])

-        return (
+        return ChosenRejectedOutputs(
             policy_chosen_logprobs,
             policy_rejected_logprobs,
+            torch.tensor(0),
+            torch.tensor(0),
+        ), ChosenRejectedOutputs(
             ref_chosen_logprobs,
             ref_rejected_logprobs,
+            torch.tensor(0),
+            torch.tensor(0),
         )

     def test_dpo_loss(self, dpo_loss, loss_inputs):
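The same fixture shape works for unit-testing a custom loss. Below is a rough sketch reusing the illustrative SimpleSigmoidDPOLoss from earlier on this page (not a torchtune class); the specific values and thresholds are arbitrary.

# Sketch of a test for a custom loss, mirroring the updated fixture above.
# from my_project.losses import SimpleSigmoidDPOLoss  # hypothetical module path
import torch

from torchtune.rlhf import ChosenRejectedOutputs


def test_custom_loss_prefers_chosen():
    policy = ChosenRejectedOutputs(
        chosen_logps=torch.tensor([-0.5]),
        rejected_logps=torch.tensor([-2.0]),
        chosen_logits=torch.tensor(0.0),
        rejected_logits=torch.tensor(0.0),
    )
    reference = ChosenRejectedOutputs(
        chosen_logps=torch.tensor([-1.0]),
        rejected_logps=torch.tensor([-1.0]),
        chosen_logits=torch.tensor(0.0),
        rejected_logits=torch.tensor(0.0),
    )
    loss, chosen_rewards, rejected_rewards = SimpleSigmoidDPOLoss(beta=0.1)(
        policy, reference
    )
    # The policy already prefers the chosen response here, so its implicit reward
    # should beat the rejected one and the loss should drop below log(2).
    assert (chosen_rewards > rejected_rewards).all()
    assert (loss < torch.log(torch.tensor(2.0))).all()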

torchtune/rlhf/__init__.py

+2 −1

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.


-from ._types import PPOStats, Trajectory
+from ._types import ChosenRejectedOutputs, PPOStats, Trajectory

 from .rewards import (
     estimate_advantages,
@@ -39,4 +39,5 @@
     "PPOStats",
     "get_batch_log_probs",
     "Trajectory",
+    "ChosenRejectedOutputs",
 ]
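With the export above, the dataclass is importable straight from torchtune.rlhf rather than the private _types module. A tiny sketch with placeholder tensor values:

import torch

from torchtune.rlhf import ChosenRejectedOutputs

outputs = ChosenRejectedOutputs(
    chosen_logps=torch.tensor([-0.5]),
    rejected_logps=torch.tensor([-1.2]),
    chosen_logits=torch.zeros(1, 4),
    rejected_logits=torch.zeros(1, 4),
)
print(outputs.chosen_logps - outputs.rejected_logps)  # tensor([0.7000])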
