Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

preference loss sign is inverted and leads to negative loss #481

Closed
wants to merge 2 commits into from

Conversation

winglian
Copy link
Contributor

Summary

In testing cases where there is no alpha/NLL loss used, the loss becomes negative, which is probably not the intended behavior. see https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L1234

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

austin362667

This comment was marked as outdated.

Copy link
Collaborator

@austin362667 austin362667 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind updating the corresponding tests? They failed in my CI as expected.

@austin362667
Copy link
Collaborator

austin362667 commented Dec 16, 2024

I think the corresponding snippet lies here:

loss = policy_nll_loss * self.alpha - losses.mean()
.

Copy link
Collaborator

@austin362667 austin362667 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@@ -408,7 +408,7 @@ def _compute_loss(
else:
preference_loss, aux_outputs = preference_loss_outputs, []

loss = alpha * chosen_nll_loss - preference_loss
Copy link
Collaborator

@shivam15s shivam15s Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the original logic is correct for all losses except dpo_loss. I think the sign change should be here instead:

loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @shivam15s Thanks for noticing this. What are your thoughts on negating each preference loss term to align with the formulas in the docstrings? This would allow us to maintain the base preference structure as nll_loss + preference_loss while making both terms represent losses to be minimized.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure @austin362667 ! I think that might help with readability.
If you have some time, could you add the fix to this PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya sure! I might need to open a new PR (base on this one) cc @winglian

shivam15s added a commit that referenced this pull request Dec 20, 2024
## Summary

Thanks to @winglian and @shivam15s noticed and fixed this
#481.

This PR suggests negating the preference loss terms to align with the
formulas in the docstrings, while maintaining the base preference
structure as `nll_loss + preference_loss`. This would make our loss
computations more consistent since both terms would represent losses to
be minimized.

[UPDATE: It seems like being addressed now in
[here](3205342#diff-3048cb37b97e27515852c200994f3257b8ae33a465421d05184713377c0895b1R150)]
This PR also tightened the tolerance in case of encountering a similar
issue.

<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <austin362667@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Shivam Sahni <shivam15800@gmail.com>
@winglian winglian closed this Dec 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants