-
Notifications
You must be signed in to change notification settings - Fork 276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Preference Loss and Refactor for Readability #484
Conversation
test/chunked_loss/test_cpo_loss.py
Outdated
losses = -( | ||
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | ||
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing | ||
) | ||
elif self.loss_type == "simpo": | ||
logits = logits - (self.simpo_gamma / self.beta) | ||
losses = ( | ||
losses = -( | ||
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) | ||
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can we have the - sign inside the brackets:
similar to https://github.com/huggingface/trl/blob/0fe73a8af5ff660becc79bff88d9e8b090dd004f/trl/trainer/dpo_trainer.py#L949
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! Thank you for reviewing.
Signed-off-by: Austin Liu <austin362667@gmail.com>
526bf4e
to
a12c2f1
Compare
Signed-off-by: Austin Liu <austin362667@gmail.com>
a12c2f1
to
ce74888
Compare
Summary
Thanks to @winglian and @shivam15s noticed and fixed this #481.
This PR suggests negating the preference loss terms to align with the formulas in the docstrings, while maintaining the base preference structure as
nll_loss + preference_loss
. This would make our loss computations more consistent since both terms would represent losses to be minimized.[UPDATE: It seems like being addressed now in here]
This PR also tightened the tolerance in case of encountering a similar issue.
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence