
nt_xent_loss numerical instability #320

Closed · jefflai108 opened this issue Oct 29, 2020 · 2 comments · Fixed by #329

@jefflai108 commented Oct 29, 2020

Issue

https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/e6b10875d59a39a4dcf382d3a599528a40ba088c/pl_bolts/losses/self_supervised_learning.py#L17-L24

I am experimenting with the SimCLR framework on audio data, and I found that in the objective function the exponential applied to the positive and negative similarities can induce numerical instability, i.e. the loss becomes NaN. I think this is not unexpected.
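
For illustration, here is a minimal repro sketch (not from the issue) that mirrors the structure of the linked loss; the batch size, dimensionality, scale factor, and seed are all hypothetical:

```python
import torch

torch.manual_seed(0)

# Hypothetical unnormalized encoder outputs with large magnitudes,
# as can happen with raw audio features (shapes are made up).
out_1 = 50 * torch.randn(4, 128)
out_2 = 50 * torch.randn(4, 128)
temperature = 0.5

out = torch.cat([out_1, out_2], dim=0)

# exp of large dot products overflows float32 to inf
sim = torch.exp(torch.mm(out, out.t()) / temperature)
mask = ~torch.eye(out.shape[0], dtype=torch.bool)
neg = sim.masked_select(mask).view(out.shape[0], -1).sum(dim=-1)

pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)

loss = -torch.log(pos / neg).mean()
print(loss)  # non-finite: nan from inf/inf, or inf from log(0)
```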

Temporary solution

Use torch.nn.CrossEntropyLoss; this is what some other open-source implementations do, e.g. https://github.com/sthalles/SimCLR/blob/master/loss/nt_xent.py (see the sketch below).
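
A minimal sketch of that approach, computing NT-Xent via cross_entropy over cosine-similarity logits; the function name and default temperature here are mine, not from the linked repo:

```python
import torch
import torch.nn.functional as F

def nt_xent_stable(out_1, out_2, temperature=0.5):
    # F.cross_entropy applies log-softmax with the log-sum-exp trick,
    # so the exponentials are never materialized and cannot overflow.
    n = out_1.shape[0]
    out = F.normalize(torch.cat([out_1, out_2], dim=0), dim=-1)

    logits = out @ out.t() / temperature        # (2n, 2n) cosine sims
    # Exclude self-similarity so a sample is never its own negative.
    logits.fill_diagonal_(float('-inf'))

    # The positive for row i in [0, n) sits at column i + n, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)])
    return F.cross_entropy(logits, targets)
```

Even with the unnormalized, large-magnitude inputs from the repro above, this returns a finite loss, because log-softmax subtracts the row maximum before exponentiating.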

@jefflai108 added the help wanted label on Oct 29, 2020
@github-actions commented

Hi! Thanks for your contribution, great first issue!

@jefflai108 changed the title from nt_xent_loss numerical stability to nt_xent_loss numerical instability on Oct 29, 2020
@ananyahjha93 mentioned this issue Nov 2, 2020
@ananyahjha93 self-assigned this Nov 3, 2020
@ananyahjha93 (Contributor) commented Nov 16, 2020

@jefflai108 out_1 and out_2 are normalized vectors and should not create numerical instability with the current loss function. Since the cosine similarity of normalized vectors lies in [-1, 1], every element after torch.exp is bounded between e^-1 and e.

Given that range for the exponentiated similarity values, we also eliminate the possibility of division by 0 in the line loss = -torch.log(pos / neg).mean(), since every term summed into neg is strictly positive.

Can you verify for me that out_1 and out_2 are normalized in your code?
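
A quick sketch of that check (variable names match the loss signature; the shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

# Normalized projections: cosine similarities are confined to [-1, 1].
out_1 = F.normalize(torch.randn(4, 128), dim=-1)
out_2 = F.normalize(torch.randn(4, 128), dim=-1)

out = torch.cat([out_1, out_2], dim=0)
cov = out @ out.t()

print(cov.min().item(), cov.max().item())  # within [-1, 1]
print(torch.exp(cov).max().item())         # at most e ~= 2.718, no overflow
```

With a temperature t the bounds scale to [e^(-1/t), e^(1/t)], which is still finite for typical values such as t = 0.5.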

@Borda added this to the v0.3 milestone on Jan 18, 2021