
Added a NTXent loss #897

Merged · 25 commits · Apr 15, 2024

Conversation

GrantMcConachie (Contributor)

A normalized temperature-scaled cross entropy (NTXent) loss for a contrastive learning objective. I am fairly new to submitting pull requests to public repos, so I didn't add a ton of tests for this outside of a batched test. Let me know if there is anything else I should add!

This is the NTXent loss from the SimCLR paper. I have done some minor testing on it and it reproduces both the loss and gradient values of pytorch_metric_learning's NTXentLoss. Let me know if you want me to add a test function as well. I was unsure whether I should put it in classification because it is really a self-supervised loss, so let me know if you want it somewhere else!
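For context, here is a minimal usage sketch of such a loss in JAX. It assumes the function ends up exposed as optax.ntxent with an (embeddings, labels) signature; the names and shapes below are illustrative, not necessarily the PR's final API.

```python
# Hypothetical usage sketch -- function name and signature are assumptions.
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
embeddings = jax.random.normal(key, (8, 16))   # batch of projected features
labels = jnp.array([0, 0, 1, 1, 2, 2, 3, 3])   # equal labels => positive pairs

loss = optax.ntxent(embeddings, labels)                          # scalar loss
grads = jax.grad(lambda e: optax.ntxent(e, labels))(embeddings)  # per-embedding gradients
```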
Commits:
  • Got rid of test.sh errors
  • test for ntxent loss
  • changed the way cosine_similarity is imported
  • fixed the output based on the default temperature scaling
  • Made the function jittable
  • simpler ntxent test; renamed the class to fix typo
  • changed == True to == 1
@GrantMcConachie (Contributor, Author)

Oh also, I know NTXent isn't really a classification loss, but I didn't know where else to put it.

@vroulet (Collaborator) left a comment


Thank you @GrantMcConachie!
This would be a great addition. I left you some comments. In addition to those comments:

  • you may create a file _contrastive.py where you would put this loss (so don't put it in _classification.py),
  • you will need to add the loss to the __init__.py file of the losses folder,
  • you will also need to add it to the docs (in docs/api/losses.rst) (see the sketch after this list).

Ping me if you have any difficulties, and thank you again!
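For concreteness, a hedged sketch of the edits these bullets describe; the file paths come from the thread, while the symbol name ntxent and the final module are assumptions (the loss later landed in _self_supervised.py rather than the suggested _contrastive.py).

```python
# optax/losses/__init__.py -- export the new loss (names are assumptions).
from optax.losses._self_supervised import ntxent

# docs/api/losses.rst additionally needs an entry for the new function,
# alongside the other documented losses.
```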

Review comments on optax/losses/_classification.py (outdated, resolved)
@vroulet (Collaborator) commented on Apr 4, 2024

Hello @GrantMcConachie!

Yes, it's a bug due to a new release of JAX; it's already been pointed out in #904. I've raised the bug internally with the JAX team and hope it gets solved soon. I don't see any quick hotfix (the bug is likely to affect many of our modules). I can ping you when it gets solved :)

@vroulet (Collaborator) commented on Apr 5, 2024

The tests have been fixed in #908, so you should be able to proceed. Again, it would be really nice if you could avoid passing through an exponential without careful tricks like the log-sum-exp trick (see https://en.wikipedia.org/wiki/LogSumExp). Alternatively, you may directly use functions like jax.nn.log_softmax that implement such a trick.
Thanks again for the PR!
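A toy illustration of the point, with values chosen only to trigger the overflow; jax.nn.log_softmax applies the log-sum-exp trick internally.

```python
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp

sims = jnp.array([1000.0, 999.0, 998.0])   # e.g. similarities divided by temperature

naive = jnp.log(jnp.exp(sims) / jnp.sum(jnp.exp(sims)))   # nan: exp overflows
stable = nn.log_softmax(sims)                             # finite and correct
by_hand = sims - logsumexp(sims)                          # the same identity, by hand
```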

@GrantMcConachie (Contributor, Author)

I am still working on this! It is trickier to implement than I thought due to the partitioning of positive and negative pairs. For example, if there are 4 embeddings with labels like this:

e1 -> 0
e2 -> 0
e3 -> 0
e4 -> 1

then your cosine similarity matrix will be a 4x4 matrix like this:

[   0       sim01+   sim02+   sim03- ]
[ sim10+      0      sim12+   sim13- ]
[ sim20+   sim21+      0      sim23- ]
[ sim30    sim31    sim32       0    ]

Where the pluses indicate positive pairs.

So we want the loss to be log(exp(sim01+) / (exp(sim01+) + exp(sim03-))) + log(exp(sim02+) / (exp(sim02+) + exp(sim03-))) + log(exp(sim10+) / (exp(sim10+) + exp(sim13-))) + .... However, a row-wise log_softmax on this matrix will return values like log(exp(sim01+) / (exp(sim01+) + exp(sim02+) + exp(sim03-))), which incorporates exp(sim02+) in the denominator when it shouldn't.

I think there's probably a way around this using logsumexp or log_softmax, but I am still trying to figure out how to do it.
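One way to express the partitioning described above is with boolean masks built from the labels; the sketch below is illustrative (the helper name is hypothetical, not the PR's code).

```python
import jax.numpy as jnp

def _pair_masks(labels):
  """Boolean masks of positive and negative pairs, with the diagonal excluded."""
  same = labels[:, None] == labels[None, :]        # same label => candidate positives
  off_diag = ~jnp.eye(labels.shape[0], dtype=bool)
  return same & off_diag, ~same                    # (positives, negatives)

labels = jnp.array([0, 0, 0, 1])                   # e1..e4 from the example above
pos, neg = _pair_masks(labels)
# For anchor e1 (row 0): positives are e2 and e3, the only negative is e4, so a
# per-pair denominator exp(sim01) + exp(sim03) leaves exp(sim02) out, as desired.
```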

@GrantMcConachie (Contributor, Author)

Hi @vroulet! I believe I implemented the loss function using the same trick that logsumexp and log_softmax use. I did not use these functions explicitly, but I was able to "normalize" the cosine similarity values by subtracting the row-wise maximum cosine similarity from each value before exponentiating, summing, and then taking the logarithm. I believe this is sufficient to avoid overflow/underflow problems, but please let me know if you see an issue with this!
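A minimal sketch of that stabilization (not the PR's exact code): subtract the row-wise maximum before exponentiating, then add it back after the log, which is exactly the identity that logsumexp and log_softmax rely on.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def _stable_row_logsumexp(sims):
  row_max = jnp.max(sims, axis=1, keepdims=True)   # per-anchor maximum
  return jnp.log(jnp.sum(jnp.exp(sims - row_max), axis=1)) + row_max[:, 0]

sims = jnp.array([[0.0, 500.0, 499.0]])
assert jnp.allclose(_stable_row_logsumexp(sims), logsumexp(sims, axis=1))
```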

@vroulet (Collaborator) left a comment


Perfect! Thank you for pushing this through!
I'm leaving some last small comments and then we'll merge :)
Also, you'll need to add the loss to the main __init__.py file (otherwise it won't be picked up in the docs); for an example, see how convex_kl_divergence is defined in __init__.py and listed in __all__ in that same file.
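A hedged sketch of what that last piece of wiring might look like, modeled on the convex_kl_divergence pointer above; the exact contents of __all__ here are placeholders.

```python
# optax/__init__.py (sketch; surrounding entries elided)
from optax.losses import ntxent

__all__ = (
    # ... existing entries ...
    "convex_kl_divergence",
    "ntxent",
    # ...
)
```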

Review comments on optax/losses/_self_supervised.py, optax/losses/_self_supervised_test.py, and docs/api/losses.rst (resolved)
@GrantMcConachie (Contributor, Author)

Thanks for all the edits! I am happy that this will get added. Please let me know if there's more I can do!

@vroulet (Collaborator) left a comment


Thank you again @GrantMcConachie !

@copybara-service bot merged commit 60bb59f into google-deepmind:main on Apr 15, 2024
6 checks passed