SimPO (Simple Preference Optimisation) #1223
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1223.

Note: links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures as of commit 0c008d4 with merge base 5c7246e.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'm not a fan of all the if checks you need to do to support the slight variation in setups for the losses; I'm worried this will invite further complex branching logic when we add more losses that may need different treatment. But I also recognize there's not an easy solution here. It seems like
Thanks so much for the review. I actually tried out a lighter-weight version of what you suggested initially; loss functions carried boolean flags indicating if they used average logprobs/were ref free. As you pointed out, however -
This (kind of) bothered me, particularly if we want a community contributor to easily add one of the many loss functions that our recipe could support.
I don't love this. I'd like for the loss functions to be atomic and just implement the logic required for the calculation - a 1:1 mapping to the maths. I've come to two solutions which could satisfy both of us. Firstly, we expose attributes on the losses:

```python
get_batch_logprobs(
    ...
    return_average_logprobs=self._loss_fn.uses_average_logprobs,
)
...
if self._loss_fn.uses_reference_logprobs:
    ref_logprobs = ...
```

After typing this out, I actually don't love this. The contract between the recipe and the loss is enforced in a vague way at runtime. Blegh. But, it's relatively lightweight.
Secondly, we enforce a protocol on the losses:

```python
from typing import Protocol

from torch import nn


class PreferenceLoss(Protocol):
    @property
    def uses_average_logprobs(self) -> bool:
        """Indicates whether the loss uses averaged vs summed logprobs."""
        ...


class DPOLoss(PreferenceLoss, nn.Module):
    @property
    def uses_average_logprobs(self) -> bool:
        return False
```

This keeps things a bit tighter and makes it easier to reason about the runtime behaviour of the recipe. (It would probably be `PreferenceLossInterface`.) Sorry for the thought vomit. Your wisdom appreciated.
Leaning towards this approach. It is a bit awkward to enforce a protocol on a loss `nn.Module`, but the contracts are clearer and it makes the checks in the recipe much cleaner. Tagging @ebsmothers for his thoughts.
Maybe I am thinking about this too simply (if so please tell me), but I feel like we are trying to overgeneralize prematurely here. We have a SimPO loss that can optionally return the average logprobs, and then we do a bunch of checks to confirm that if we're doing SimPO we set `use_average_logprobs=True`.

In my mind it's OK to be dumb and explicit here, and it also makes the code easier to understand. Why not: (1) just have SimPO always take the average (especially if this is a key contribution from the paper itself), (2) scrap the `use_average_logprobs` flag
(ofc with appropriate code comments to explain what's happening). Modulo that, this looks good to go to me. I notice you didn't add a new config or anything, do you plan to? (Not saying you should btw as we have quite a lot already, just curious)
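For concreteness, a rough, self-contained sketch of what this "dumb and explicit" approach could look like; every name here (including `SimPOLoss` and `compute_preference_loss`) is an illustrative stand-in rather than the actual recipe code:

```python
import torch
from torch import nn


class SimPOLoss(nn.Module):
    """Stand-in marker for the SimPO loss; only its type matters in this sketch."""


def compute_preference_loss(
    loss_fn: nn.Module,
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor | None = None,
    ref_rejected_logps: torch.Tensor | None = None,
) -> torch.Tensor:
    # Explicit branch on the concrete loss type: no protocol, no boolean flag.
    if isinstance(loss_fn, SimPOLoss):
        # SimPO is reference-free, so the reference log-probs are never needed.
        return loss_fn(policy_chosen_logps, policy_rejected_logps)
    # DPO-style losses compare policy log-probs against a frozen reference model.
    return loss_fn(
        policy_chosen_logps,
        policy_rejected_logps,
        ref_chosen_logps,
        ref_rejected_logps,
    )
```

The cost is an extra `isinstance` branch per reference-free loss, but the control flow stays visible in one place in the recipe.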
If one should ever think to oneself "Would Evan suggest I'm overgeneralizing prematurely here?", then one should consider whether one is indeed overgeneralizing prematurely. Great suggestion @ebsmothers. I agree - let's reconsider when this is necessary.
btw the losses aren't returning the average logprobs, that's from `get_batch_logprobs`.
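For context, here is a minimal sketch of what such a helper computes; the signature is illustrative (not torchtune's actual API) and it assumes the labels are already aligned with the logits:

```python
import torch


def get_batch_logprobs(
    logits: torch.Tensor,   # [batch_size, seq_len, vocab_size]
    labels: torch.Tensor,   # [batch_size, seq_len], aligned with logits
    return_average_logprobs: bool = False,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Per-sequence log-prob of the label tokens: summed by default,
    length-normalised (averaged over non-ignored tokens) when requested."""
    mask = labels != ignore_index
    # Replace ignored positions with a valid index before gathering.
    safe_labels = torch.where(mask, labels, torch.zeros_like(labels))
    per_token_logps = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=safe_labels.unsqueeze(2)
    ).squeeze(2)
    summed = (per_token_logps * mask).sum(dim=-1)
    if return_average_logprobs:
        return summed / mask.sum(dim=-1)
    return summed
```

So whether a loss sees summed or averaged log-probs is decided at this call site in the recipe, not inside the loss itself.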
yeah, that's definitely a lot clearer thanks to @ebsmothers. looks great to me.
```
@@ -180,11 +187,13 @@ def setup(self, cfg: DictConfig) -> None:
        # log config with parameter override
        self._metric_logger.log_config(cfg)

        self._model_compile = cfg.compile
```
nit: `self._is_model_compiled` or similar is a bit more clear
fyi this is what we use in all our recipes - can amend
Context
What is the purpose of this PR? It adds a new feature (resolves #1037, closes #1036).
Another alignment PR???
This PR adds support for SimPO, a kind of direct preference optimisation technique (i.e. one which does not require a separate reward model) which also eliminates the need for a reference model. The authors propose two key contributions to achieve this:

- a length-normalised implicit reward: the average log-probability of a response under the policy, in place of DPO's policy/reference log-ratio, and
- a target reward margin γ, requiring the chosen response's reward to exceed the rejected response's reward by at least γ.
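For reference, a minimal sketch of what the resulting objective computes; the class name, signature, and default hyperparameters below are illustrative rather than the final API:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SimPOLoss(nn.Module):
    """loss = -log sigmoid(beta * (avg_logp_chosen - avg_logp_rejected) - gamma)

    The inputs are *average* (length-normalised) per-token log-probs of the chosen
    and rejected responses under the policy; no reference model is involved.
    """

    def __init__(self, beta: float = 2.0, gamma: float = 0.5):
        super().__init__()
        self.beta = beta    # scales the implicit reward difference
        self.gamma = gamma  # target reward margin between chosen and rejected

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,    # [batch_size], averaged over tokens
        policy_rejected_logps: torch.Tensor,  # [batch_size], averaged over tokens
    ) -> torch.Tensor:
        logits = policy_chosen_logps - policy_rejected_logps
        return -F.logsigmoid(self.beta * logits - self.gamma)
```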
Since this loss function changes the DPO loss in a minor way, we can use the existing DPO recipe here. I've added some validation in the DPO recipe to ensure a valid loss function is used, and a `use_average_logprobs` parameter, which is hidden from the config by default, but we warn the user if they try to use SimPO without this parameter. I plan to document these better in the recipe docs PR I've been discussing on Discord, which I'll put up soon.

I've added functionality in the recipe to only generate reference logprobs if the loss function is one which uses a reference model (like the current DPO-style losses we have). Since there are other reference-free direct-preference-optimisation-style losses (like CPO, which I think would be a nice "first-issue" to have in the backlog), it should be straightforward to extend this recipe to use additional loss functions.
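As a rough illustration of the validation idea, something along these lines could run during recipe setup; the tuple of loss names is illustrative, not the exact list used by the recipe:

```python
from torch import nn


def validate_preference_loss(loss_fn: nn.Module) -> None:
    """Fail fast in setup() if the configured loss isn't a preference loss."""
    supported = ("DPOLoss", "RSOLoss", "IPOLoss", "SimPOLoss")  # illustrative names
    if type(loss_fn).__name__ not in supported:
        raise ValueError(
            f"{type(loss_fn).__name__} is not a supported preference loss for this "
            f"recipe; expected one of {supported}."
        )
```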
Tests
Replication results
Please see training plots below for our implementation:
and for TRL in their `CPOTrainer`, with my best attempt at ensuring identical hyperparameters:

Unwarranted opinions
This is all well and good, Salman, but where is the meat and potatoes? Where is the protein? How do we make sense of all this??
Overall, I found SimPO to be a bit fussy to work with. It was very sensitive to learning rate and scheduler settings, and unless you tuned it right, optimisation would degenerate pretty quickly and the run would fail. For both the TRL and torchtune runs, I couldn't stop the chosen rewards from decreasing and becoming more negative during training. Note that this doesn't necessarily mean the method failed, since SimPO optimises for a margin between the chosen and rejected logprobs, which we did achieve here. The authors note this is kind-of expected behaviour, so maybe fine?
Interestingly enough, I played around with CPO in TRL and my results were much cleaner - I found it overall much more stable, and my chosen rewards played nicely - I produced some cool plots of the optimiser finding a stable decision boundary pretty quickly. Check out an example run here. I think adding CPO would be a pretty straightforward "good-first-issue", and I'll follow up with an issue for it.
A unit test has been added for the loss.

- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`