Add a variant of CPO, SimPO (#1703)
* add a variant of cpo: simpo

* correct cpo-simpo loss

* avoid 0 int error in logging

* add simpo description

* Update trl/trainer/cpo_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* fix formatting

* add test for simpo

* Update docs/source/cpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add a docstring for simpogamma

* move simpo description to the above docstring

* change simpo description in the doc

* formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
fe1ixxu and kashif authored Jun 6, 2024
1 parent 3eb9ccb commit b8b972f
Showing 4 changed files with 28 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/source/cpo_trainer.mdx
@@ -81,6 +81,8 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithm, identify an overfitting issue, and propose an alternative loss, which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen and rejected completion pair, so the smaller the `beta`, the larger this gap. Per the paper, the loss is averaged over the log-likelihoods of the completion (unlike CPO, which sums them).
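
For intuition, a worked toy example (numbers ours, not from the paper): with `beta = 0.1` the IPO loss `(logits - 1 / (2 * beta)) ** 2` is minimized when the chosen-vs-rejected log-likelihood-ratio gap equals `1 / (2 * 0.1) = 5`; with `beta = 0.5` the target gap shrinks to `1`.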

+ [SimPO](https://arxiv.org/abs/2405.14734) is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, set `loss_type="simpo"` in the `CPOConfig`.
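
A minimal sketch of enabling it (the `output_dir` value and hyperparameters are illustrative, not recommendations):

```python
from trl import CPOConfig

config = CPOConfig(
    output_dir="cpo-simpo-output",  # illustrative path
    loss_type="simpo",
    simpo_gamma=0.5,  # target reward margin gamma
    beta=0.1,
)
```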


## Logging

2 changes: 2 additions & 0 deletions tests/test_cpo_trainer.py
@@ -84,6 +84,8 @@ def _init_dummy_dataset(self):
["t5", "hinge"],
["gpt2", "ipo"],
["t5", "ipo"],
["gpt2", "simpo"],
["t5", "simpo"],
]
)
def test_cpo_trainer(self, name, loss_type):
5 changes: 4 additions & 1 deletion trl/trainer/cpo_config.py
@@ -41,6 +41,8 @@ class CPOConfig(TrainingArguments):
The type of loss to use. This argument is required if you want to use the default data collator.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
+ simpo_gamma (`float`, defaults to `0.5`):
+ A target reward margin for the SimPO loss, used only when the "simpo" option is enabled.
padding_value (`int`, defaults to `None`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
@@ -64,8 +66,9 @@

beta: float = 0.1
label_smoothing: float = 0
- loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+ loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "simpo"] = "sigmoid"
disable_dropout: bool = True
+ simpo_gamma: float = 0.5

label_pad_token_id: int = -100
padding_value: int = None
24 changes: 20 additions & 4 deletions trl/trainer/cpo_trainer.py
@@ -268,6 +268,9 @@ def make_inputs_require_grad(module, input, output):
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type

+ if args.loss_type == "simpo":
+ self.simpo_gamma = args.simpo_gamma

self._stored_metrics = defaultdict(lambda: defaultdict(list))

# Compute this only on the main process for faster data processing.
@@ -585,7 +588,16 @@ def cpo_loss(
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative CPO loss.
if self.loss_type == "sigmoid":

if self.loss_type == "simpo":
gamma_logratios = self.simpo_gamma / self.beta
logits = logits - gamma_logratios
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
@@ -598,7 +610,7 @@
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'simpo']"
)

chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
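
For reference, a standalone sketch of the SimPO objective this hunk implements (function name and toy inputs are ours, not TRL API; it assumes the log-probs are already length-normalized, as the `average_log_prob` change below ensures):

```python
import torch
import torch.nn.functional as F

def simpo_loss_sketch(chosen_logps, rejected_logps, beta=0.1, gamma=0.5, label_smoothing=0.0):
    # Margin-shifted log-ratio: beta * logits == beta * (logp_w - logp_l) - gamma.
    logits = chosen_logps - rejected_logps - gamma / beta
    # Label-smoothed sigmoid loss, matching the branch above.
    return (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )

# Toy length-normalized log-probs (illustrative values only).
loss = simpo_loss_sketch(torch.tensor([-0.8]), torch.tensor([-1.5]))
```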
@@ -691,12 +703,16 @@ def cross_entropy_loss(logits, labels):
return loss

labels = concatenated_batch["concatenated_labels"].clone()
- nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

+ if self.loss_type != "simpo":
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+ else:
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
- average_log_prob=self.loss_type == "ipo",
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
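
Two effects of this hunk, shown as a toy sketch (variable names and values are ours, for illustration only): with `average_log_prob=True` the sequence log-probability becomes a per-token mean rather than a sum, which is what makes the SimPO reward length-normalized, and the BC/NLL regularizer is replaced by a zero tensor since SimPO does not use one.

```python
import torch

# Per-token log-probs of a 3-token completion (illustrative values).
token_logps = torch.tensor([-1.2, -0.7, -2.0])

summed = token_logps.sum()     # used by "sigmoid", "hinge", "kto_pair"
averaged = token_logps.mean()  # used by "ipo" and now "simpo"

# SimPO skips the cross-entropy regularizer entirely.
nll_loss = torch.tensor(0.0)
```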
