
feat: anchored pref optimization #1928

Merged (12 commits) on Aug 14, 2024
docs/source/dpo_trainer.mdx (10 changes: 6 additions & 4 deletions)
@@ -134,13 +134,13 @@ The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper the

The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).

-The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO, which corresponds to the forward KL. To use this loss, set `loss_type="exo"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large.
+The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO, which corresponds to the forward KL. To use this loss, set `loss_type="exo_pair"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large.

-The [NCA](https://huggingface.co/papers/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use this loss, set `loss_type="nca"` in the [`DPOConfig`].
+The [NCA](https://huggingface.co/papers/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use this loss, set `loss_type="nca_pair"` in the [`DPOConfig`].

-The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set `loss_type="robust_dpo"` in the [`DPOConfig`].
+The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set `loss_type="robust"` in the [`DPOConfig`].

-The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward, so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set `loss_type="bco"` in the [`DPOConfig`].
+The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward, so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set `loss_type="bco_pair"` in the [`DPOConfig`].
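As a quick, hedged illustration of how the variants above are selected (not part of this diff; the output path and values are placeholders), one of them can be enabled through [`DPOConfig`] like so:

```python
# Illustrative sketch only: pick one of the documented loss variants via DPOConfig.
# "robust" is used here; "exo_pair", "nca_pair", or "bco_pair" would slot in the same way.
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-output",  # placeholder output path
    loss_type="robust",       # conservative DPO loss that tolerates noisy preference labels
    label_smoothing=0.1,      # assumed label-noise probability, between 0.0 and 0.5
    beta=0.1,                 # beta parameter of the DPO loss (higher means staying closer to the reference model)
)
```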

The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
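Similarly, a hedged sketch of toggling the TR-DPO callback described above (the values are illustrative placeholders, not recommendations from this PR):

```python
# Illustrative sketch only: enable TR-DPO style syncing of the reference model.
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-output",    # placeholder output path
    sync_ref_model=True,        # turn on the reference-model sync callback
    ref_model_mixup_alpha=0.6,  # mixing weight applied when updating the reference weights
    ref_model_sync_steps=512,   # sync every 512 optimizer steps (illustrative value)
)
```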

@@ -150,6 +150,8 @@ The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is

The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, alignment algorithms use paired preferences at the sample level, which does not ensure alignment at the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant (in the first order) over the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.

The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. To use these losses, set `loss_type="apo_zero"` or `loss_type="apo_down"` in the [`DPOConfig`].
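Written out, the two anchored objectives match the code added to `dpo_loss` further down in this PR (a restatement of Eqns. (7) and (8) of the APO paper). With ρ_w and ρ_l denoting the chosen and rejected log-ratios log πθ(y|x) − log π_ref(y|x), and σ the logistic function:

$$
\rho_w = \log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}, \qquad
\rho_l = \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)},
$$

$$
\mathcal{L}_{\text{APO-zero}} = \bigl(1 - \sigma(\beta \rho_w)\bigr) + \sigma(\beta \rho_l),
\qquad
\mathcal{L}_{\text{APO-down}} = \sigma(\beta \rho_w) + \Bigl(1 - \sigma\bigl(\beta (\rho_w - \rho_l)\bigr)\Bigr).
$$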

### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are the most efficient if the load is about equally distributed between experts.
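The rest of this section is collapsed in the diff view. As a hedged sketch of what enabling the auxiliary loss typically looks like for a Mixtral-style MoE (the checkpoint name and coefficient below are placeholders, not taken from this PR):

```python
# Hedged sketch: ask the model to return router logits so the load-balancing
# auxiliary loss can be added to the training loss. Assumes a MoE architecture
# (e.g. Mixtral) whose config exposes these fields.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # placeholder MoE checkpoint
    output_router_logits=True,      # needed to compute the auxiliary load-balancing loss
    router_aux_loss_coef=0.001,     # weight of the auxiliary term (illustrative value)
)
```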
tests/test_dpo_trainer.py (2 changes: 2 additions & 0 deletions)
@@ -163,6 +163,8 @@ def _init_dummy_image_dataset(self):
["gpt2", "robust", True],
["gpt2", "exo_pair", False],
["t5", "exo_pair", True],
["gpt2", "apo_zero", True],
["t5", "apo_down", False],
]
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
trl/trainer/dpo_config.py (15 changes: 14 additions & 1 deletion)
@@ -52,6 +52,8 @@ class DPOConfig(TrainingArguments):
- `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) paper.
- `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
- `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
- `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.

label_pad_token_id (`int`, *optional*, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
@@ -105,7 +107,18 @@ class DPOConfig(TrainingArguments):
beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal[
"sigmoid", "hinge", "ipo", "exo_pair", "nca_pair", "robust", "bco_pair", "sppo_hard", "aot", "aot_pair"
"sigmoid",
"hinge",
"ipo",
"exo_pair",
"nca_pair",
"robust",
"bco_pair",
"sppo_hard",
"aot",
"aot_pair",
"apo_zero",
"apo_down",
] = "sigmoid"
label_pad_token_id: int = -100
padding_value: Optional[int] = None
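For context (not part of the diff), here is a hedged end-to-end sketch of how the new values are meant to be used with the `DPOConfig`/`DPOTrainer` API of the TRL version this PR targets; the checkpoint, dataset, and hyperparameters are placeholders:

```python
# Illustrative sketch only: train with the new APO-zero loss on a tiny dummy
# preference dataset in the standard prompt/chosen/rejected format.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "chosen": ["Paris."],
        "rejected": ["London."],
    }
)

training_args = DPOConfig(
    output_dir="apo-zero-run",  # placeholder output path
    loss_type="apo_zero",       # use "apo_down" instead if the chosen outputs are worse than the model
    beta=0.1,
    per_device_train_batch_size=1,
    max_steps=1,                # keep the sketch cheap to run
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,             # DPOTrainer creates a frozen reference copy when None is passed
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,        # argument name as of the TRL version this PR targets
)
trainer.train()
```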
trl/trainer/dpo_trainer.py (31 changes: 27 additions & 4 deletions)
@@ -140,7 +140,7 @@ def __init__(
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
-loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair"] = "sigmoid",
+loss_type: Optional[str] = None,
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
@@ -481,7 +481,7 @@ def make_inputs_require_grad(module, input, output):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False

-if loss_type != "sigmoid":
+if loss_type is not None:
warnings.warn(
"You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
@@ -491,7 +491,10 @@ def make_inputs_require_grad(module, input, output):
"You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_smoothing = label_smoothing
-if args.loss_type in ["hinge", "ipo", "bco_pair"] and args.label_smoothing > 0:
+if (
+    args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
+    and args.label_smoothing > 0
+):
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
@@ -1191,9 +1194,29 @@ def dpo_loss(
- F.logsigmoid(-self.beta * delta) * self.label_smoothing
)

elif self.loss_type == "apo_zero":
# Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
# Use this loss when you believe the chosen outputs are better than your model's default output

losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood
losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood

losses = losses_chosen + losses_rejected

elif self.loss_type == "apo_down":
# Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
# Use this loss when you believe the chosen outputs are worse than your model's default output

losses_chosen = F.sigmoid(self.beta * chosen_logratios) # Decrease chosen likelihood
losses_rejected = 1 - F.sigmoid(
self.beta * (chosen_logratios - rejected_logratios)
) # Decrease rejected likelihood more

losses = losses_chosen + losses_rejected

else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust', 'exo_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']"
)

chosen_rewards = (
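Finally, a self-contained restatement of the two new branches on dummy tensors (β and the log-ratio values are made up), which can help verify the anchoring behaviour described in the docs:

```python
# Standalone sketch mirroring the apo_zero / apo_down branches added to dpo_loss.
import torch
import torch.nn.functional as F

beta = 0.1
# log(pi_theta(y|x) / pi_ref(y|x)) for chosen and rejected completions, batch of 3
chosen_logratios = torch.tensor([0.5, -0.2, 1.0])
rejected_logratios = torch.tensor([-0.3, 0.1, 0.4])

# apo_zero (Eqn. 7): push chosen above the reference, push rejected below it
apo_zero_losses = (1 - F.sigmoid(beta * chosen_logratios)) + F.sigmoid(beta * rejected_logratios)

# apo_down (Eqn. 8): push chosen down, and push rejected down even harder
apo_down_losses = F.sigmoid(beta * chosen_logratios) + (
    1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
)

print(apo_zero_losses)  # per-example losses, shape (3,)
print(apo_down_losses)
```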