Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a NTXent loss #897

Merged
merged 25 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e2bd891
Created a NTXent loss
GrantMcConachie Mar 27, 2024
a8f9816
Merge branch 'google-deepmind:main' into patch-1
GrantMcConachie Mar 27, 2024
b2179c3
Update _classification.py
GrantMcConachie Mar 27, 2024
f6599b8
Update _classification_test.py
GrantMcConachie Mar 27, 2024
9e1908a
Merge branch 'google-deepmind:main' into patch-1
GrantMcConachie Mar 27, 2024
271ccb0
Update _classification.py
GrantMcConachie Mar 27, 2024
6ae70bf
Update _classification_test.py
GrantMcConachie Mar 27, 2024
b5dbb93
Merge branch 'google-deepmind:main' into patch-1
GrantMcConachie Apr 1, 2024
cc79720
Update _classification.py
GrantMcConachie Apr 1, 2024
4d4368a
Update _classification_test.py
GrantMcConachie Apr 1, 2024
28f86eb
Update _classification.py
GrantMcConachie Apr 1, 2024
45fbefa
Merge branch 'google-deepmind:main' into patch-1
GrantMcConachie Apr 2, 2024
2900b7f
added contrastive loss specific scripts
GrantMcConachie Apr 2, 2024
d0a4ce1
moved ntxent loss to contrastive specific script and added to __init__
GrantMcConachie Apr 2, 2024
8ee73e7
added whitespace
GrantMcConachie Apr 2, 2024
39928cc
changed the file names and fixed the docstrings for the functions
GrantMcConachie Apr 3, 2024
4f05cc2
added ntxent to losses.rst and changed import in __init__
GrantMcConachie Apr 3, 2024
1d8e47f
typo fixed
GrantMcConachie Apr 3, 2024
35f3347
changed typehints
GrantMcConachie Apr 4, 2024
284060d
changed back type hints
GrantMcConachie Apr 4, 2024
4c2d1b3
Merge branch 'google-deepmind:main' into patch-1
GrantMcConachie Apr 12, 2024
a651bcb
another test case that is important to have
GrantMcConachie Apr 12, 2024
38afc10
changed the type hints and added a more robust numerical stability ca…
GrantMcConachie Apr 12, 2024
f66718c
took out trailing whitespace
GrantMcConachie Apr 12, 2024
1cef8b5
minor changes to comments, docstrings, placements
GrantMcConachie Apr 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Losses
kl_divergence
l2_loss
log_cosh
ntxent
safe_softmax_cross_entropy
sigmoid_binary_cross_entropy
sigmoid_focal_loss
Expand Down Expand Up @@ -61,6 +62,10 @@ Log hyperbolic cosine loss
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: log_cosh

Normalized temperature scaled cross-entropy (NT-Xent) loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ntxent

Sigmoid binary cross-entropy
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sigmoid_binary_cross_entropy
Expand Down
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
kl_divergence = losses.kl_divergence
l2_loss = losses.l2_loss
log_cosh = losses.log_cosh
ntxent = losses.ntxent
sigmoid_binary_cross_entropy = losses.sigmoid_binary_cross_entropy
smooth_labels = losses.smooth_labels
softmax_cross_entropy = losses.softmax_cross_entropy
Expand Down Expand Up @@ -306,6 +307,7 @@
"MultiTransformState",
"nadam",
"nadamw",
"ntxent",
"noisy_sgd",
"novograd",
"NonNegativeParamsState",
Expand Down
1 change: 1 addition & 0 deletions optax/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._smoothing import smooth_labels
from optax.losses._self_supervised import ntxent
1 change: 0 additions & 1 deletion optax/losses/_classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,5 @@ def test_ignore_negative(self):
assert all(ce_loss[self.ts == 0] > 0)
assert all(focal_loss[self.ts == 0] == 0)


if __name__ == '__main__':
absltest.main()
86 changes: 86 additions & 0 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Self supervised losses."""

import chex
from jax import lax
import jax.numpy as jnp
from optax.losses._regression import cosine_similarity


def ntxent(
    embeddings: chex.Array,
    labels: chex.Array,
    temperature: chex.Numeric = 0.07
) -> chex.Numeric:
  """Normalized temperature scaled cross entropy loss (NT-Xent).

  References:
    T. Chen et al `A Simple Framework for Contrastive Learning of Visual
    Representations <http://arxiv.org/abs/2002.05709>`_, 2020
    kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss

  Args:
    embeddings: batch of embeddings, with shape [batch, feature_length]
    labels: labels for groups that are positive pairs. e.g. if you have
      a batch of 4 embeddings and the first two and last two were positive
      pairs your `labels` should look like [0, 0, 1, 1]. labels SHOULD NOT
      be all the same (e.g. [0, 0, 0, 0]) you will get a NaN result.
      Shape [batch]
    temperature: temperature scaling parameter.

  Returns:
    A scalar loss value of NT-Xent values averaged over all positive
    pairs

  Raises:
    ValueError: if the batch dimension of `labels` does not match the
      batch dimension of `embeddings`.

  .. versionadded:: 0.2.3
  """
  chex.assert_type([embeddings], float)
  if labels.shape[0] != embeddings.shape[0]:
    raise ValueError(
        'label dimension should match batch dimension in embeddings'
    )

  # pairwise cosine similarity matrix, scaled by the temperature
  xcs = cosine_similarity(
      embeddings[None, :, :], embeddings[:, None, :]
  ) / temperature

  # boolean masks: `matches` marks positive pairs (same label, self-similarity
  # excluded), `diffs` marks negative pairs (different labels)
  same_label = labels[:, None] == labels[None, :]
  diffs = ~same_label
  matches = same_label & ~jnp.eye(labels.shape[0], dtype=bool)

  # mask irrelevant entries with -inf so exp() maps them to 0
  xcs_diffs = jnp.where(diffs, xcs, -jnp.inf)
  xcs_matches = jnp.where(matches, xcs, -jnp.inf)

  # subtract the row-wise max (log-sum-exp trick) for numeric stability;
  # stop_gradient because the shift must not contribute to the gradient
  comb = jnp.concatenate((xcs_diffs, xcs_matches), axis=-1)
  xcs_max = jnp.max(comb, axis=1, keepdims=True)
  xcs_shift_diffs = xcs_diffs - lax.stop_gradient(xcs_max)
  xcs_shift_matches = xcs_matches - lax.stop_gradient(xcs_max)

  # cross entropy of each positive pair against all negatives plus itself
  numer = xcs_shift_matches
  numer_exp = jnp.exp(xcs_shift_matches)
  denom = jnp.sum(jnp.exp(xcs_shift_diffs), axis=1, keepdims=True)
  denom += numer_exp
  log_softm = numer - jnp.log(denom)
  loss = -jnp.where(matches, log_softm, 0.0).sum() / matches.sum()

  return loss
51 changes: 51 additions & 0 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optax.losses._self_supervised."""

from absl.testing import parameterized

import chex
import jax.numpy as jnp
import numpy as np

from optax.losses import _self_supervised


class NtxentTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # a batch of four 2-d embeddings; rows 0-1 and rows 2-3 are designed
    # to be near/far so the two label layouts give distinct losses
    self.ys = jnp.array([
        [-1.9540, 1.0780],
        [ 0.2380, -0.5703],
        [ 1.8745, -0.0195],
        [-0.6719, -1.9210],
    ])
    self.ts_1 = jnp.array([0,0,1,1])
    self.ts_2 = jnp.array([0,0,0,1])
    # Calculated expected output
    self.exp_1 = jnp.array(14.01032)
    self.exp_2 = jnp.array(8.968544)

  @chex.all_variants
  def test_batched(self):
    """Tests for a full batch."""
    loss_fn = self.variant(_self_supervised.ntxent)
    cases = ((self.ts_1, self.exp_1), (self.ts_2, self.exp_2))
    for labels, expected in cases:
      np.testing.assert_allclose(
          loss_fn(self.ys, labels), expected, atol=1e-4)
Loading