
SimPO (Simple Preference Optimisation) #1223

Merged
SalmanMohammadi merged 8 commits into pytorch:main from the simpo branch on Aug 8, 2024

Conversation

@SalmanMohammadi (Collaborator) commented Jul 25, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

(resolves #1037, closes #1036)

Another alignment PR???

This PR adds support for SimPO, a kind of direct preference optimisation technique (i.e. one which does not require a separate reward model) which also eliminates the need for a reference model. The authors propose two key contributions to achieve this:

  1. Using the average logprobs over the responses. They argue that the downstream, trained policy model is in fact maximising the average log-likelihood of a generated sequence (since we're ultimately optimising for a policy model which can generate preferred responses). However, the authors claim that the DPO objective instead optimises an explicit formulation of the reward function, which causes a discrepancy between the training objective and the generation-time log-likelihood. They formulate a length-normalised reward objective using the average logprobs and claim that, with correct tuning, this is sufficient to prevent degeneracy without a reference model.
  2. Introducing an additional hyperparameter, γ (gamma): an SVM-style target margin between the chosen and rejected logprobs (a rough sketch of the resulting loss follows below).
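
To make the objective concrete, here is a minimal sketch of the SimPO loss computation in plain PyTorch. The function and argument names are illustrative, not the ones used in this PR, and it assumes the chosen/rejected logprobs passed in are already length-normalised averages:

import torch
import torch.nn.functional as F

def simpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,    # average logprobs of the chosen responses, shape [batch]
    policy_rejected_logps: torch.Tensor,  # average logprobs of the rejected responses, shape [batch]
    beta: float = 2.0,                    # reward scaling
    gamma: float = 0.5,                   # SVM-style target margin between chosen and rejected rewards
) -> torch.Tensor:
    # length-normalised reward margin; note there are no reference logprobs anywhere
    logits = policy_chosen_logps - policy_rejected_logps
    # per-example losses; callers would typically take the mean over the batch
    return -F.logsigmoid(beta * logits - gamma)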

Since this loss function is only a minor modification of the DPO loss, we can reuse the existing DPO recipe here. I've added some validation in the DPO recipe to ensure a valid loss function is used, along with a use_average_logprobs parameter, which is hidden from the config by default; we warn the user if they try to use SimPO without it. I plan to document these better in the recipe docs PR I've been discussing on Discord, which I'll put up soon.
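
As a rough illustration of the kind of validation meant here (the helper and its arguments are hypothetical, not the recipe's actual code):

import logging

log = logging.getLogger(__name__)

def warn_on_simpo_misconfig(loss_fn, use_average_logprobs: bool) -> None:
    # hypothetical check: SimPO is defined over average logprobs, so warn if that's disabled
    if type(loss_fn).__name__ == "SimPOLoss" and not use_average_logprobs:
        log.warning(
            "SimPO expects length-normalised (average) logprobs; "
            "set use_average_logprobs=True to match the paper's formulation."
        )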

I've added functionality in the recipe to only generate reference logprobs if the loss function is one which uses a reference model (like the current DPO-style losses we have). Since there are other reference-free direct-preference-optimisation-style losses (like CPO, which I think would be a nice "first-issue" to have in the backlog), it should be straightforward to extend this recipe to use additional loss functions.
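
A sketch of that gating, with assumed names (concatenated_forward here stands in for the recipe's forward helper, and loss_is_reference_free is an illustrative flag rather than an existing attribute):

from typing import Callable, Optional, Tuple
import torch

def maybe_reference_logprobs(
    ref_model: torch.nn.Module,
    batch: dict,
    concatenated_forward: Callable,
    loss_is_reference_free: bool,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # skip the (frozen) reference forward pass entirely for reference-free losses like SimPO
    if loss_is_reference_free:
        return None, None
    with torch.no_grad():
        ref_chosen_logps, ref_rejected_logps = concatenated_forward(ref_model, batch)
    return ref_chosen_logps, ref_rejected_logps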

Tests

Replication results

Please see training plots below for our implementation:

[image: training plots from the torchtune SimPO run]

and for TRL in their CPOTrainer, with my best attempt at ensuring identical hyperparameters:

[image: training plots from the TRL CPOTrainer SimPO run]

Unwarranted opinions

This is all well and good, Salman, but where is the meat and potatoes? Where is the protein? How do we make sense of all this??

Overall, I found SimPO to be a bit fussy to work with. It was very sensitive to learning rate and scheduler settings, and unless you tuned it right, optimisation would degenerate pretty quickly and the run would fail. For both the TRL and torchtune runs, I couldn't stop the chosen rewards from decreasing and becoming more negative during training. Note that this doesn't necessarily mean the method failed, since SimPO optimises for a margin between the chosen and rejected logprobs, which we achieved here. The authors note this is kind of expected behaviour, so maybe fine?

Interestingly enough, I played around with CPO in TRL and my results were much cleaner - I found it overall much more stable, and my chosen rewards played nicely - I produced some cool plots of the optimiser finding a stable decision boundary pretty quickly. Check out an example run here. I think adding CPO would be a pretty straightforward "good-first-issue", and I'll follow up with an issue for it.

A unit test has been added for the loss.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot (bot) commented Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1223

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 0c008d4 with merge base 5c7246e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 25, 2024
@SalmanMohammadi changed the title from "Simpo" to "SimPO (Simple Preference Optimisation)" on Jul 25, 2024
@SalmanMohammadi requested review from RdoubleA and ebsmothers, and removed the request for RdoubleA, on July 30, 2024 10:33
@RdoubleA (Contributor) commented Aug 2, 2024

I'm not a fan of all the if checks you need to do to support the slight variations in setup across the losses; I'm worried this will invite further complex branching logic when we add more losses that may need different treatment. But I also recognize there's not an easy solution here. It seems like use_average_logprobs and reference_free are both loss-specific, so instead of dealing with loss-specific logic in the recipe/config layer, what are your thoughts on this:

  • call get_batch_logprobs in the loss module instead of in the recipe. This would mean concatenated_forward would just return all_logits (or you could split them back into chosen and rejected), and the loss forward would take logits instead and call get_batch_logprobs itself. This way, SimPO can use average logprobs without exposing this logic all the way up in the config layer, which imo makes more sense because it's a low-level, module-specific parameter rather than a recipe-level parameter.
  • knowing if a loss is reference-free is a bit harder. How about making this an attribute on all DPO loss modules (sketched just below)? Then you just check it directly instead of parsing the loss name, keeping a mapping of reference-free losses, etc. The downside is that there's no way to enforce future DPO loss modules having this attribute
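
A rough sketch of what that attribute could look like (class stubs only, illustrative rather than the actual torchtune modules):

import torch.nn as nn

class DPOLoss(nn.Module):
    # DPO consumes reference logprobs
    reference_free: bool = False

class SimPOLoss(nn.Module):
    # SimPO never touches a reference model
    reference_free: bool = True

# the recipe then branches on the attribute instead of parsing the loss name, e.g.
# if not self._loss_fn.reference_free:
#     ref_logprobs = ...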

@SalmanMohammadi (Collaborator, Author) commented Aug 4, 2024

Thanks so much for the review.

I actually tried out a lighter-weight version of what you suggested initially; loss functions carried boolean flags indicating if they used average logprobs/were ref free. As you pointed out, however -

The downside is that there's no way to enforce future DPO loss modules having this attribute

This (kind of) bothered me, particularly if we want a community contributor to easily add one of the many loss functions that our recipe could support.

call get_batch_logprobs in the loss module instead of in the recipe. This would mean concatenated_forward would just return all_logits (or you can split them back to chosen, rejected) and loss forward would take logits instead and call get_batch_logprobs.

I don't love this. I'd like the loss functions to be atomic and just implement the logic required for the calculation - a 1:1 mapping to the maths. get_batch_logprobs is about massaging inputs for consumption by the loss - an operation which belongs in the recipe imo. This follows how I use this function in the PPO recipe (which is where I copied the util over from).
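
For context, a rough sketch of what a get_batch_logprobs-style utility computes (the signature and defaults here are assumptions; label shifting and padding conventions depend on the actual recipe):

import torch

def get_batch_logprobs_sketch(
    logits: torch.Tensor,   # [batch, seq_len, vocab]
    labels: torch.Tensor,   # [batch, seq_len], padded with ignore_index
    ignore_index: int = -100,
    return_average_logprobs: bool = False,
) -> torch.Tensor:
    mask = labels != ignore_index
    # gather the logprob of each label token, zeroing out the padded positions
    safe_labels = labels.masked_fill(~mask, 0)
    per_token = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=safe_labels.unsqueeze(-1)
    ).squeeze(-1)
    summed = (per_token * mask).sum(-1)
    if return_average_logprobs:
        # length-normalised logprobs, which is what SimPO's formulation uses
        return summed / mask.sum(-1).clamp(min=1)
    return summed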

I've come to two solutions which could satisfy both of us. Firstly, we expose attributes on the losses: uses_average_logprobs and uses_reference_logprobs. We call get_batch_logprobs as-is in the recipe, and we just use the uses_reference_logprobs attr in the recipe for generating ref logprobs, or not.

get_batch_logprobs(
    ...
    return_average_logprobs = self._loss_fn.uses_average_logprobs
)

...

if self._loss_fn.uses_reference_logprobs:
    ref_logprobs = ...

After typing this out, I actually don't love this. The contract between the recipe and the loss is enforced in a vague way at runtime. Blegh. But, it's relatively lightweight.

  2. A bit heavy-handed, but we define a protocol with something like:
from typing import Protocol

import torch.nn as nn


class PreferenceLoss(Protocol):
    @property
    def uses_average_logprobs(self) -> bool:
        """
        Indicates whether the loss uses averaged (rather than summed) logprobs.
        """
        ...


class DPOLoss(PreferenceLoss, nn.Module):
    @property
    def uses_average_logprobs(self) -> bool:
        return False

This keeps things a bit tighter and makes it easier to reason about the runtime behaviour of the recipe.

(would probably be PreferenceLossInterface)

Sorry for the thought vomit. Your wisdom appreciated.

@RdoubleA (Contributor) commented Aug 5, 2024

  2. A bit heavy-handed, but we define a protocol with something like:

Leaning towards this approach. It is a bit awkward to enforce a protocol on a loss nn.Module, but the contracts are clearer and it makes the checks in the recipe much cleaner. Tagging @ebsmothers for his thoughts.

@ebsmothers (Contributor) commented
Maybe I am thinking about this too simply (if so please tell me), but I feel like we are trying to overgeneralize prematurely here. We have a SimPO loss that can optionally return the average logprobs and then we do a bunch of checks to confirm that if we're doing SimPO we set use_average_logprobs=True and we're setting self._loss_type appropriately.

In my mind it's OK to be dumb and explicit here, and it also makes the code easier to understand. Why not: (1) just have SimPO always take the average (especially if this is a key contribution from the paper itself), (2) scrap self._loss_type altogether (as well as the whole loss_args in the train method), and just do

if isinstance(self._loss_fn, SimPOLoss):
    # SimPO is reference-free, so it only needs the policy logprobs
    loss, chosen_rewards, rejected_rewards = self._loss_fn(policy_chosen_log_probs, policy_rejected_log_probs)
else:
    # the existing logic
    ...

(ofc with appropriate code comments to explain what's happening).

Modulo that, this looks good to go to me. I notice you didn't add a new config or anything, do you plan to? (Not saying you should btw as we have quite a lot already, just curious)

@SalmanMohammadi (Collaborator, Author) commented Aug 8, 2024

If one should ever think to oneself "Would Evan suggest I'm overgeneralizing prematurely here?", then one should consider whether one is indeed overgeneralizing prematurely.

Great suggestion @ebsmothers. I agree - let's reconsider when this becomes necessary.
I don't think I'll be adding a config for this. We currently have 4 loss types for the DPO recipe, and I'd rather just see them documented on the recipe doc page.

We have a SimPO loss that can optionally return the average logprobs and then we do a bunch of checks to confirm that if we're doing SimPO we set use_average_logprobs=True and we're setting self._loss_type appropriately.

btw the losses aren't returning the average logprobs, that's from get_batch_logprobs. I think it's fine atm to just do this check again there to keep things really simple, though.

@RdoubleA (Contributor) left a comment

yeah, that's definitely a lot clearer thanks to @ebsmothers. looks great to me.

@@ -180,11 +187,13 @@ def setup(self, cfg: DictConfig) -> None:
# log config with parameter override
self._metric_logger.log_config(cfg)

self._model_compile = cfg.compile
Contributor:

nit: self._is_model_compiled or similar is a bit more clear

@SalmanMohammadi (Collaborator, Author) replied:

fyi this is what we use in all our recipes - can amend

@SalmanMohammadi merged commit 0a40771 into pytorch:main on Aug 8, 2024 (21 of 29 checks passed)
@SalmanMohammadi deleted the simpo branch on Aug 20, 2024
@SalmanMohammadi mentioned this pull request on Oct 18, 2024 (13 tasks)