preference loss sign is inverted and leads to negative loss #481
Conversation
Would you mind updating the corresponding tests? They failed in my CI as expected.
I think the corresponding snippet is here: Line 514 in 0bb6c72
Thank you!
```
@@ -408,7 +408,7 @@ def _compute_loss(
        else:
            preference_loss, aux_outputs = preference_loss_outputs, []

        loss = alpha * chosen_nll_loss - preference_loss
```
Seems like the original logic is correct for all losses except dpo_loss. I think the sign change should be here instead:
```python
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
```
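For intuition, here is a minimal, self-contained sketch of that sign convention (not the repo's actual `dpo_loss` implementation; the tensor values and `beta` default are made up). It shows that the negated form is non-negative on its own, so it behaves as a loss even when no NLL term is added:

```python
import torch
import torch.nn.functional as F

def dpo_preference_term(chosen_logps, rejected_logps, full_target, beta=0.1):
    # Same shape of computation as the suggested line above:
    # -logsigmoid(x) >= 0 for every x, so this term can be minimized directly.
    logits_diff = beta * (chosen_logps - rejected_logps)
    return -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)

chosen_logps = torch.tensor([-12.0, -15.0])
rejected_logps = torch.tensor([-14.0, -13.0])
full_target = torch.zeros(4, 8)  # concatenated chosen + rejected rows (assumed layout)
print(dpo_preference_term(chosen_logps, rejected_logps, full_target))  # > 0
```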
Hi @shivam15s, thanks for noticing this. What are your thoughts on negating each preference loss term so it aligns with the formulas in the docstrings? That would let us keep the base preference structure as `nll_loss + preference_loss` while making both terms represent losses to be minimized.
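A rough sketch of the convention being proposed (plain numbers, names illustrative only, not the library's API): each preference term is already a loss, so the base class simply adds it to the weighted NLL term.

```python
alpha = 1.0
chosen_nll_loss = 2.3      # assumed NLL value on the chosen responses
preference_loss = 0.45     # e.g. a -logsigmoid(...) style term, always >= 0

loss = alpha * chosen_nll_loss + preference_loss
print(loss)  # 2.75; with alpha = 0 the result is still >= 0
```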
Sure @austin362667! I think that might help with readability.
If you have some time, could you add the fix to this PR?
Ya sure! I might need to open a new PR (based on this one). cc @winglian
## Summary

Thanks to @winglian and @shivam15s for noticing and fixing this in #481. This PR suggests negating the preference loss terms to align with the formulas in the docstrings, while maintaining the base preference structure as `nll_loss + preference_loss`. This makes the loss computations more consistent, since both terms then represent losses to be minimized.

[UPDATE: It now appears to be addressed [here](3205342#diff-3048cb37b97e27515852c200994f3257b8ae33a465421d05184713377c0895b1R150)]

This PR also tightens the test tolerance in case a similar issue comes up again.

## Testing Done

- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <austin362667@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Shivam Sahni <shivam15800@gmail.com>
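On the tolerance note in the summary above, here is a hedged illustration (the values and tensors are assumptions, not the repo's actual test settings) of what tightening a comparison tolerance looks like with PyTorch's testing utilities:

```python
import torch

expected = torch.tensor([0.4500, 1.8500])
actual = torch.tensor([0.4503, 1.8497])

# A loose bound lets larger numerical drift pass silently; a tighter
# rtol/atol makes the comparison fail as soon as results diverge meaningfully.
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)
```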
## Summary

In test cases where no alpha/NLL loss is used, the loss becomes negative, which is probably not the intended behavior. See https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L1234
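A small illustration of the symptom (names and values are assumed for the example, not the library's API): if `preference_loss` is already a non-negative quantity to minimize, subtracting it drives the total below zero as soon as the NLL term is disabled.

```python
def combined_loss(alpha, chosen_nll_loss, preference_loss):
    # Sign convention reported here: the preference term is subtracted.
    return alpha * chosen_nll_loss - preference_loss

print(combined_loss(alpha=1.0, chosen_nll_loss=2.3, preference_loss=0.45))  #  1.85
print(combined_loss(alpha=0.0, chosen_nll_loss=0.0, preference_loss=0.45))  # -0.45  <- negative loss
```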
- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence