Added a NTXent loss #897
Conversation
This is the NTXent loss from the SimCLR paper. I have done some minor testing on it, and it reproduces both the loss and gradient values of pytorch_metric_learning's NTXentLoss. Let me know if you want me to add a test function as well. I was unsure whether I should put it in classification, because it is really a self-supervised loss, so let me know if you want it somewhere else!
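For context, a minimal usage sketch of the kind of call this PR adds; the `ntxent` name, its placement under `optax`, and the default temperature are assumptions for illustration, not a fixed API:

```python
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
# e.g. projected embeddings of two augmented views of 4 images, stacked.
embeddings = jax.random.normal(key, (8, 16))
# Equal labels mark positive pairs (each image and its augmentation share a label).
labels = jnp.array([0, 0, 1, 1, 2, 2, 3, 3])

# Assumed signature for the loss proposed in this PR.
loss = optax.ntxent(embeddings, labels)
print(loss)  # a scalar; lower means positives are closer than negatives
```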
Got rid of test.sh errors
test for ntxent loss
changed the way cosine_similarity is imported.
fixed the output based on the default temperature scaling.
Made the function jittable
simpler ntxent test. renamed the class to fix typo.
changed == True to == 1
Oh also, I know NTXent isn't really a classification loss, but I didn't know where else to put it.
Thank you @GrantMcConachie!
This would be a great addition. I left you some comments. In addition to those comments:

- you may create a file `_contrastive.py` where you would put this loss (so don't put it in `_classification.py`),
- you will need to add the loss to the `__init__.py` file of the losses folder,
- you will also need to add it in the docs (in `docs/api/losses.rst`).

(A sketch of these additions follows below.)
Ping me if you have any difficulties and thank you again!
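A sketch of what these additions might look like; the package layout and import line are assumed from the comment above, not copied from the repository:

```python
# optax/losses/__init__.py (assumed path) -- re-export the new loss
from optax.losses._contrastive import ntxent
```

The docs change would then add the new symbol to `docs/api/losses.rst` alongside the existing loss entries.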
Hello @GrantMcConachie! Yes, it's a bug due to a new release of jax; it has already been pointed out in #904. I've reported the bug upstream to the jax team, and I hope it can get solved. I don't see any quick hotfix (the bug likely affects many of our modules). I can ping you when it gets solved :)
The tests have been fixed in #908, so you should be able to proceed. Again, it would be really nice if you could avoid passing through an exponential without careful tricks like the log-sum-exp trick (see https://en.wikipedia.org/wiki/LogSumExp). Alternatively, you may directly use functions like logsumexp or log_softmax.
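For reference, a small sketch of the log-sum-exp trick in JAX; the values are illustrative:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1000.0, 1001.0, 1002.0])

# Naive log(sum(exp(x))) overflows to inf in float32.
naive = jnp.log(jnp.sum(jnp.exp(x)))

# Log-sum-exp trick: shift by the max before exponentiating.
m = jnp.max(x)
stable = m + jnp.log(jnp.sum(jnp.exp(x - m)))  # ~1002.41

# jax.scipy.special.logsumexp applies the same shift internally.
assert jnp.allclose(stable, logsumexp(x))
```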
I am still working on this! It is trickier to implement than I thought due to the partitioning of positive and negative pairs. For example, if there are 4 embeddings where pairs of them share a label, the cosine similarity matrix will be a 4x4 matrix whose off-diagonal entries between same-label embeddings are the positive pairs. So for each anchor we want the numerator of the loss to cover only its positive pairs, while the denominator sums over all pairs except the anchor itself. I think there's probably a way around this using logsumexp or log_softmax, but I am still trying to figure out how to do it.
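To make the partitioning concrete, here is an illustrative sketch (not the code in this PR) that builds the positive-pair mask from the labels and lets `jax.nn.log_softmax` handle the exponential; the function name, masking convention, and temperature default are assumptions:

```python
import jax
import jax.numpy as jnp

def ntxent_sketch(embeddings, labels, temperature=0.07):
    """Illustrative NTXent-style loss: same label => positive pair, self-pairs excluded."""
    # Cosine similarities scaled by temperature.
    z = embeddings / jnp.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = (z @ z.T) / temperature                       # [n, n]

    n = labels.shape[0]
    self_mask = jnp.eye(n, dtype=bool)
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask

    # The denominator runs over every pair except the anchor itself,
    # so push self-similarities to a very negative value before the softmax.
    sim = jnp.where(self_mask, -1e9, sim)
    log_prob = jax.nn.log_softmax(sim, axis=1)          # no explicit exp => stable

    # Average -log p over each anchor's positive pairs, then over anchors.
    pos_count = jnp.maximum(jnp.sum(pos_mask, axis=1), 1)
    per_anchor = -jnp.sum(jnp.where(pos_mask, log_prob, 0.0), axis=1) / pos_count
    return jnp.mean(per_anchor)
```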
Hi @vroulet! I believe that I implemented the loss function using the same trick that logsumexp and log_softmax use. I did not use these functions explicitly, but I was able to "normalize" the cosine similarity values by subtracting the row-wise maximum cosine similarity from each value before exponentiating, summing, and taking the logarithm. I believe this is sufficient to avoid overflow/underflow problems, but please let me know if you see an issue with this!
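That row-wise max subtraction is the same stabilization `jax.nn.log_softmax` applies internally; a quick check with made-up similarity values:

```python
import jax
import jax.numpy as jnp

sim = jnp.array([[30.0, 50.0, 40.0],
                 [50.0, 20.0, 10.0]])  # e.g. cosine similarities / temperature

# Subtract the row-wise max before exponentiating, summing, and taking the log ...
m = jnp.max(sim, axis=1, keepdims=True)
manual = (sim - m) - jnp.log(jnp.sum(jnp.exp(sim - m), axis=1, keepdims=True))

# ... which matches jax.nn.log_softmax, since it applies the same shift internally.
assert jnp.allclose(manual, jax.nn.log_softmax(sim, axis=1))
```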
Perfect! Thank you for pushing this through!
I'm leaving last small comments and we'll merge :)
Also, you'll need to add the loss to the main `__init__.py` file (otherwise it won't be caught in the doc); for an example, see the definition of e.g. `convex_kl_divergence` in `__init__.py` and its appearance in the definition of `__all__` in the same `__init__.py` file.
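As a sketch of that pattern (the exact import lines in `optax/__init__.py` are assumptions; `convex_kl_divergence` is just the existing example referenced above):

```python
# optax/__init__.py (sketch; only the relevant lines)
from optax.losses import convex_kl_divergence  # existing export to mirror
from optax.losses import ntxent                # the new loss from this PR

__all__ = (
    # ... other exports ...
    "convex_kl_divergence",
    "ntxent",
    # ... other exports ...
)
```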
Thanks for all the edits! I am happy that this will get added. Please let me know if there's more I can do!
Thank you again @GrantMcConachie!
A normalized temperature-scaled cross-entropy (NTXent) loss for a contrastive learning objective. I am fairly new to submitting pull requests to public repos, so I didn't add a ton of tests for this outside of a batched test. Let me know if there is anything else I should add!
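If more coverage is wanted, one minimal extra check could be a jit-consistency test; the import path and function name here are hypothetical:

```python
import jax
import jax.numpy as jnp

def test_ntxent_jit_matches_eager():
    from optax.losses import ntxent  # hypothetical import path

    key = jax.random.PRNGKey(0)
    embeddings = jax.random.normal(key, (8, 16))
    labels = jnp.array([0, 0, 1, 1, 2, 2, 3, 3])

    eager = ntxent(embeddings, labels)
    jitted = jax.jit(ntxent)(embeddings, labels)

    assert jnp.isfinite(eager)
    assert jnp.allclose(eager, jitted, rtol=1e-5)
```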