ntxent fix #946
Conversation
…llows for 0 vector embeddings.
…ine similarity.
…to patch-1
Thanks again @GrantMcConachie!
Would such logic be potentially implementable?
Hi @vroulet! Yes, I think this is possible! I will work on it and let you know.
Hello again @vroulet, I tried a `jnp.where()`-based approach to calculate the cross entropy, rather than the `cosine_similarity` function with the epsilon. This gives the same cosine similarity matrix; however, the gradient resulted in NaNs. I also tried a `jnp.where()` variant that keeps the 0.0 epsilon value for the `cosine_similarity` calculation, and this also resulted in NaNs in the gradient. I am out of ideas for other ways to implement `jnp.where()`, so I believe the best way to go about this is to add the epsilon in the cosine similarity! Let me know if you have any more suggestions for implementing `jnp.where()` though!
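For illustration, a minimal sketch (not the exact code tried above) of why a single `jnp.where()` does not protect the gradient when an embedding is the 0 vector, while an epsilon in the norm does:

```python
import jax
import jax.numpy as jnp

def cosine_sim_where(x, y):
  # Naive masking: jnp.where fixes the forward value, but the inf/NaN produced
  # on the untaken branch still leaks into the gradient (0 * inf = NaN).
  norm = jnp.sqrt(jnp.sum(x**2)) * jnp.sqrt(jnp.sum(y**2))
  return jnp.where(norm == 0.0, 0.0, jnp.dot(x, y) / norm)

def cosine_sim_eps(x, y, eps=1e-12):
  # Adding a small epsilon keeps both the value and the gradient finite.
  norm = jnp.sqrt(jnp.sum(x**2) + eps) * jnp.sqrt(jnp.sum(y**2) + eps)
  return jnp.dot(x, y) / norm

x, y = jnp.zeros(3), jnp.ones(3)
print(jax.grad(cosine_sim_where)(x, y))  # [nan nan nan]
print(jax.grad(cosine_sim_eps)(x, y))    # finite (if large) values
```

(The classic workaround is a second `jnp.where` on the inputs of the division as well, but adding the epsilon, as this PR does, is simpler here.)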
Hello @GrantMcConachie, Thanks again for this contribution; this issue made me look at it again and it's well done :) Ah, and by the way, you may add a doctest in the loss if you are on it. Understanding what the proper shapes should be, etc., is not necessarily evident for the user, and this would help. (Look at the docstring of Adam for example; you'll see an Examples section where you can format some code that will appear nicely in the docs.)
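A hypothetical sketch of what such an Examples section could look like (the signature, shapes, and names below are assumptions for illustration, not the doctest that ended up in the PR):

```python
def ntxent(embeddings, labels, temperature=0.07):
  """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

  Examples:
    >>> import jax.numpy as jnp
    >>> from optax import losses
    >>> # 3 embeddings of dimension 2; the first two share a label.
    >>> embeddings = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
    >>> labels = jnp.array([0, 0, 1])
    >>> loss = losses.ntxent(embeddings, labels)  # doctest: +SKIP
  """
  ...
```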
Hi @vroulet! You are definitely right. Adding this epsilon term to the cosine similarity only fixed the 0 vector embedding issue. For the case where all labels are the same, I think the loss should be 0: if there are no negative pairs (`diffs` is filled with 0s), the denominator and the numerator inside the log are the same. I was confused at first, but the equation from https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss, where I took a lot of inspiration to build this function, is more general in that you don't need to have exactly 1 positive pair in your embeddings, and it gives the same evaluation: running that loss with all labels the same gives you 0. In conclusion, adding the epsilon term in the cosine similarity alleviates the 0 vector embedding problem, and the case in which all labels are the same should evaluate to 0 loss. Let me know if you agree! I will start working on the doctest soon!
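Concretely, the per-anchor loss on that page has (up to notation) the form

$$\ell_i = -\log \frac{\exp(s_{i,p}/\tau)}{\exp(s_{i,p}/\tau) + \sum_{n \in N(i)} \exp(s_{i,n}/\tau)}$$

where $s$ are cosine similarities, $\tau$ is the temperature, $p$ indexes the positive of anchor $i$, and $N(i)$ its negatives; when $N(i)$ is empty the fraction is 1 and $\ell_i = -\log 1 = 0$.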
Hi @vroulet! Just wanted to let you know I added a doctest! Let me know what you think.
Instead of a hard-coded 1e-12, could we perhaps replace it with the machine epsilon of the embeddings' dtype?
Hi @fabianp! Yes I can add this in. |
optax/losses/_self_supervised.py
@@ -55,7 +86,8 @@ def ntxent(
   # cosine similarity matrix
   xcs = (
       _regression.cosine_similarity(
-          embeddings[None, :, :], embeddings[:, None, :]
+          embeddings[None, :, :], embeddings[:, None, :],
+          epsilon=np.finfo(embeddings.dtype).eps
no need to import numpy, you can do the same with jnp instead of np
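For illustration, a minimal sketch of the suggested change (the variable name is just for the example):

```python
import jax.numpy as jnp

embeddings = jnp.ones((4, 8), dtype=jnp.float32)

# Machine epsilon taken from jax.numpy directly, so no `import numpy as np` is needed.
eps = jnp.finfo(embeddings.dtype).eps  # ~1.1920929e-07 for float32
```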
Got it! Will change shortly.
@fabianp is something in particular holding this PR up?
There were some internal errors and then I forgot about it. Taking a look into it now.
I added an epsilon value to the cosine similarity function to avoid the NaNs that were occurring when you had a label vector [0, 0, 0] or when one of your embeddings was the 0 vector.
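As a rough usage sketch of the two failure cases the fix targets (assuming the public entry point is `optax.losses.ntxent`; values are illustrative):

```python
import jax
import jax.numpy as jnp
from optax import losses

embeddings = jnp.array([[0.0, 0.0],   # a 0-vector embedding used to produce NaNs
                        [0.3, 0.4],
                        [0.5, 0.6]])
labels = jnp.array([0, 0, 0])         # all labels identical: no negative pairs

loss_fn = lambda e: losses.ntxent(e, labels)
print(loss_fn(embeddings))            # finite, and 0 when there are no negatives
print(jax.grad(loss_fn)(embeddings))  # no NaNs with the epsilon fix
```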