From b5d77e54a688bb432953c7fa608d59b075d7a979 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 23 Apr 2024 17:06:40 -0400
Subject: [PATCH 01/12] Added a way to work around the problem with NaNs in
 the loss. This also allows for 0-vector embeddings.

---
 optax/losses/_self_supervised.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 7d1d00bd7..bac0250a0 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -34,8 +34,7 @@ def ntxent(
     embeddings: batch of embeddings, with shape [batch, feature_length]
     labels: labels for groups that are positive pairs. e.g. if you have a
       batch of 4 embeddings and the first two and last two were positive pairs your
-      `labels` should look like [0, 0, 1, 1]. labels SHOULD NOT be all the same
-      (e.g. [0, 0, 0, 0]) you will get a NaN result. Shape [batch]
+      `labels` should look like [0, 0, 1, 1]. Shape [batch]
     temperature: temperature scaling parameter.
 
   Returns:
@@ -60,6 +59,9 @@ def ntxent(
       / temperature
   )
 
+  # if 0 vector or all same label
+  xcs = jnp.where(jnp.isnan(xcs), 0.0, xcs)
+
   # finding positive and negative pairs
   labels1 = jnp.expand_dims(labels, axis=1)
   labels2 = jnp.expand_dims(labels, axis=0)

From 030ff4e607f92c53bfc8708c9b57851e9f5ae8b5 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 23 Apr 2024 17:07:22 -0400
Subject: [PATCH 02/12] Added a 0 vector embedding test to make sure this
 functionality works

---
 optax/losses/_self_supervised_test.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/optax/losses/_self_supervised_test.py b/optax/losses/_self_supervised_test.py
index 4a8f0afd6..34e3c8179 100644
--- a/optax/losses/_self_supervised_test.py
+++ b/optax/losses/_self_supervised_test.py
@@ -31,11 +31,18 @@ def setUp(self):
         [1.8745, -0.0195],
         [-0.6719, -1.9210],
     ])
+    self.ys_2 = jnp.array([
+        [0.0, 0.0],
+        [ 0.2380, -0.5703],
+        [ 1.8745, -0.0195],
+        [-0.6719, -1.9210],
+    ])
     self.ts_1 = jnp.array([0, 0, 1, 1])
     self.ts_2 = jnp.array([0, 0, 0, 1])
     # Calculated expected output
     self.exp_1 = jnp.array(14.01032)
     self.exp_2 = jnp.array(8.968544)
+    self.exp_3 = jnp.array(9.2889)
 
   @chex.all_variants
   def test_batched(self):
@@ -52,6 +59,12 @@ def test_batched(self):
         atol=1e-4,
     )
 
+    np.testing.assert_allclose(
+        self.variant(_self_supervised.ntxent)(self.ys_2, self.ts_1),
+        self.exp_3,
+        atol=1e-4,
+    )
+
 
 if __name__ == '__main__':
   absltest.main()
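The failure mode patches 01/12 and 02/12 target, as a minimal runnable
sketch (plain jax.numpy rather than the optax internals; the variable names
here are illustrative only):

    import jax.numpy as jnp

    a = jnp.array([0.0, 0.0])         # a zero-vector embedding, as in ys_2
    b = jnp.array([0.2380, -0.5703])

    # Cosine similarity divides by the product of the two norms, so a zero
    # vector makes the denominator 0 and the ratio 0/0 = NaN.
    cos = jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
    print(cos)                                  # nan
    # The PATCH 01 workaround masks NaN entries after the fact:
    print(jnp.where(jnp.isnan(cos), 0.0, cos))  # 0.0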
From 75fb5dbc27df11343a61d13bf52ecd23158b5ca3 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 23 Apr 2024 18:03:09 -0400
Subject: [PATCH 03/12] took away the replacement of NaNs and added an
 epsilon value to the cosine similarity.

---
 optax/losses/_self_supervised.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index bac0250a0..f47d8fd6a 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -54,14 +54,12 @@ def ntxent(
   # cosine similarity matrix
   xcs = (
       _regression.cosine_similarity(
-          embeddings[None, :, :], embeddings[:, None, :]
+          embeddings[None, :, :], embeddings[:, None, :],
+          eps=1e-12
       )
       / temperature
   )
 
-  # if 0 vector or all same label
-  xcs = jnp.where(jnp.isnan(xcs), 0.0, xcs)
-
   # finding positive and negative pairs
   labels1 = jnp.expand_dims(labels, axis=1)
   labels2 = jnp.expand_dims(labels, axis=0)

From 1358d702bb1b19eeba401d7977b0f4fdffccb379 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 23 Apr 2024 18:16:48 -0400
Subject: [PATCH 04/12] spelling fix

---
 optax/losses/_self_supervised.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index f47d8fd6a..03dfef6c8 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -55,7 +55,7 @@ def ntxent(
   xcs = (
       _regression.cosine_similarity(
           embeddings[None, :, :], embeddings[:, None, :],
-          eps=1e-12
+          epsilon=1e-12
       )
       / temperature
   )
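Patches 03/12 and 04/12 replace the after-the-fact masking with a guard at
the source: an epsilon keeps the cosine-similarity denominator away from
zero, so no NaN is produced in the first place and genuine NaNs elsewhere
are no longer silently hidden. A rough standalone sketch of the idea (the
exact epsilon semantics live in optax.losses.cosine_similarity; this helper
only illustrates the principle):

    import jax.numpy as jnp

    def cosine_sim(a, b, epsilon=1e-12):
      # Clamp the norm product away from zero before dividing.
      denom = jnp.maximum(jnp.linalg.norm(a) * jnp.linalg.norm(b), epsilon)
      return jnp.dot(a, b) / denom

    # A zero vector now yields similarity 0.0 instead of NaN.
    print(cosine_sim(jnp.zeros(2), jnp.array([0.2380, -0.5703])))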
From 11ee46b026c3d57dbdf8e2a7fa20e4ef71223c36 Mon Sep 17 00:00:00 2001
From: unknown
Date: Mon, 6 May 2024 16:46:13 -0400
Subject: [PATCH 05/12] Added a doctest

---
 optax/losses/_self_supervised.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 03dfef6c8..aa8a9c113 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -25,6 +25,35 @@ def ntxent(
 ) -> chex.Numeric:
   """Normalized temperature scaled cross entropy loss (NT-Xent).
 
+  Examples:
+    >>> import jax
+    >>> import optax
+    >>> import jax.numpy as jnp
+    >>>
+    >>> key = jax.random.key(42)
+    >>> key1, key2, key3 = jax.random.split(key, 3)
+    >>> x = jax.random.normal(key1, shape=(4,2))
+    >>> labels = jnp.array([0, 0, 1, 1])
+    >>> print("input:", x, "\nlabels:", labels)
+    input: [[-0.9155995   1.5534698 ]
+    [ 0.2623586  -1.5908985 ]
+    [-0.15977189  0.480501  ]
+    [ 0.58389133  0.10497775]]
+    labels: [0 0 1 1]
+    >>>
+    >>> w = jax.random.normal(key2, shape=(2,1)) # params
+    >>> b = jax.random.normal(key3, shape=(1,)) # params
+    >>> out = x @ w + b # model
+    >>> print("Embeddings:", out)
+    Embeddings: [[-1.0076267]
+    [-1.2960069]
+    [-1.1829865]
+    [-1.3485558]]
+    >>>
+    >>> loss = optax.ntxent(out, labels)
+    >>> print("loss:", loss)
+    loss: 1.0986123
+
   References:
     T. Chen et al `A Simple Framework for Contrastive Learning of Visual
     Representations <http://arxiv.org/abs/2002.05709>`_, 2020

From 037f4926c1f57a2761ce13fee0b3916e488b64cd Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 May 2024 09:58:53 -0400
Subject: [PATCH 06/12] rewrite docstring to pass syntax checks

---
 optax/losses/_self_supervised.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index aa8a9c113..1dbe98356 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -34,22 +34,23 @@ def ntxent(
     >>> key1, key2, key3 = jax.random.split(key, 3)
     >>> x = jax.random.normal(key1, shape=(4,2))
     >>> labels = jnp.array([0, 0, 1, 1])
-    >>> print("input:", x, "\nlabels:", labels)
+    >>> print("input:", x)
     input: [[-0.9155995   1.5534698 ]
     [ 0.2623586  -1.5908985 ]
     [-0.15977189  0.480501  ]
     [ 0.58389133  0.10497775]]
+    >>> print("labels:", labels)
     labels: [0 0 1 1]
     >>>
     >>> w = jax.random.normal(key2, shape=(2,1)) # params
     >>> b = jax.random.normal(key3, shape=(1,)) # params
     >>> out = x @ w + b # model
+    >>>
     >>> print("Embeddings:", out)
     Embeddings: [[-1.0076267]
     [-1.2960069]
     [-1.1829865]
     [-1.3485558]]
-    >>>
     >>> loss = optax.ntxent(out, labels)
     >>> print("loss:", loss)
     loss: 1.0986123

From b1d8eb35877c746230e27a4061afd4f453a67770 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 May 2024 12:59:45 -0400
Subject: [PATCH 07/12] truncated decimal points in doctest

---
 optax/losses/_self_supervised.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 1dbe98356..ad2d125b1 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -30,15 +30,16 @@ def ntxent(
     >>> import optax
     >>> import jax.numpy as jnp
    >>>
-    >>> key = jax.random.key(42)
+    >>> key = jax.random.key(12345)
     >>> key1, key2, key3 = jax.random.split(key, 3)
     >>> x = jax.random.normal(key1, shape=(4,2))
     >>> labels = jnp.array([0, 0, 1, 1])
-    >>> print("input:", x)
-    input: [[-0.9155995   1.5534698 ]
-    [ 0.2623586  -1.5908985 ]
-    [-0.15977189  0.480501  ]
-    [ 0.58389133  0.10497775]]
+    >>>
+    >>> print("input:", jnp.around(x, 2))
+    input: [[-1.22  0.42]
+    [ 0.64  0.68]
+    [ 1.75 -0.01]
+    [ 0.37  0.44]]
     >>> print("labels:", labels)
     labels: [0 0 1 1]
     >>>
@@ -46,14 +47,14 @@ def ntxent(
     >>> b = jax.random.normal(key3, shape=(1,)) # params
     >>> out = x @ w + b # model
     >>>
-    >>> print("Embeddings:", out)
-    Embeddings: [[-1.0076267]
-    [-1.2960069]
-    [-1.1829865]
-    [-1.3485558]]
+    >>> print("Embeddings:", jnp.around(out, 2))
+    Embeddings: [[1.39     ]
+    [1.5799999]
+    [0.34     ]
+    [1.23     ]]
     >>> loss = optax.ntxent(out, labels)
-    >>> print("loss:", loss)
-    loss: 1.0986123
+    >>> print("loss:", jnp.round(loss, 2))
+    loss: 1.1
 
   References:
     T. Chen et al `A Simple Framework for Contrastive Learning of Visual
From 24edd0bb5a4f341456e5c2a44be9a4684d7a1347 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 May 2024 14:31:10 -0400
Subject: [PATCH 08/12] added a space in the print statements for the doctest

---
 optax/losses/_self_supervised.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index ad2d125b1..7c1e44e7f 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -35,26 +35,25 @@ def ntxent(
     >>> x = jax.random.normal(key1, shape=(4,2))
     >>> labels = jnp.array([0, 0, 1, 1])
     >>>
-    >>> print("input:", jnp.around(x, 2))
-    input: [[-1.22  0.42]
-    [ 0.64  0.68]
-    [ 1.75 -0.01]
-    [ 0.37  0.44]]
-    >>> print("labels:", labels)
+    >>> print("input:", x)
+    input: [[-0.9155995   1.5534698 ]
+    [ 0.2623586  -1.5908985 ]
+    [-0.15977189  0.480501  ]
+    [ 0.58389133  0.10497775]]
     labels: [0 0 1 1]
     >>>
     >>> w = jax.random.normal(key2, shape=(2,1)) # params
     >>> b = jax.random.normal(key3, shape=(1,)) # params
     >>> out = x @ w + b # model
     >>>
-    >>> print("Embeddings:", jnp.around(out, 2))
-    Embeddings: [[1.39     ]
-    [1.5799999]
-    [0.34     ]
-    [1.23     ]]
+    >>> print("Embeddings:", 2)
+    Embeddings: [[-1.0076267]
+    [-1.2960069]
+    [-1.1829865]
+    [-1.3485558]]
     >>> loss = optax.ntxent(out, labels)
-    >>> print("loss:", jnp.round(loss, 2))
-    loss: 1.1
+    >>> print("loss:", loss)
+    loss: 1.0986123
 
   References:
     T. Chen et al `A Simple Framework for Contrastive Learning of Visual

From bb325d7dcb1341488065875558f377b9510075f8 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 May 2024 15:41:24 -0400
Subject: [PATCH 09/12] minor fixes to doctest

---
 optax/losses/_self_supervised.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 7c1e44e7f..d0b3cbcd6 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -40,13 +40,14 @@ def ntxent(
     [ 0.2623586  -1.5908985 ]
     [-0.15977189  0.480501  ]
     [ 0.58389133  0.10497775]]
+    >>> print("labels:", labels)
     labels: [0 0 1 1]
     >>>
     >>> w = jax.random.normal(key2, shape=(2,1)) # params
     >>> b = jax.random.normal(key3, shape=(1,)) # params
     >>> out = x @ w + b # model
     >>>
-    >>> print("Embeddings:", 2)
+    >>> print("Embeddings:", out)
     Embeddings: [[-1.0076267]
     [-1.2960069]
     [-1.1829865]

From a98cec8a596bd1703af6ac10dbe371d0f06b9edb Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 7 May 2024 15:57:57 -0400
Subject: [PATCH 10/12] changed rng key

---
 optax/losses/_self_supervised.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index d0b3cbcd6..7f82afb80 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -30,7 +30,7 @@ def ntxent(
     >>> import optax
     >>> import jax.numpy as jnp
     >>>
-    >>> key = jax.random.key(12345)
+    >>> key = jax.random.key(42)
     >>> key1, key2, key3 = jax.random.split(key, 3)
     >>> x = jax.random.normal(key1, shape=(4,2))
     >>> labels = jnp.array([0, 0, 1, 1])
From d05337b89948609fa18880ea3fe0f2b12e6c74ae Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 31 May 2024 11:13:54 -0400
Subject: [PATCH 11/12] changed epsilon to match dtype

---
 optax/losses/_self_supervised.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 7f82afb80..956b55016 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -18,6 +18,7 @@
 from jax import lax
 import jax.numpy as jnp
 from optax.losses import _regression
+import numpy as np
 
 
 def ntxent(
@@ -86,7 +87,7 @@ def ntxent(
   xcs = (
       _regression.cosine_similarity(
           embeddings[None, :, :], embeddings[:, None, :],
-          epsilon=1e-12
+          epsilon=np.finfo(embeddings.dtype).eps
       )
       / temperature
   )

From 1bf4b2eb46255a66c8ea8a400f8458ce2fafb04c Mon Sep 17 00:00:00 2001
From: unknown
Date: Mon, 10 Jun 2024 09:41:08 -0400
Subject: [PATCH 12/12] changed np to jnp to get rid of numpy import

---
 optax/losses/_self_supervised.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 956b55016..9d6aab66f 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -18,7 +18,6 @@
 from jax import lax
 import jax.numpy as jnp
 from optax.losses import _regression
-import numpy as np
 
 
 def ntxent(
@@ -87,7 +86,7 @@ def ntxent(
   xcs = (
       _regression.cosine_similarity(
           embeddings[None, :, :], embeddings[:, None, :],
-          epsilon=np.finfo(embeddings.dtype).eps
+          epsilon=jnp.finfo(embeddings.dtype).eps
       )
       / temperature
   )
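With the whole series applied, ntxent stays finite for zero-vector
embeddings, and the epsilon tracks the input dtype via jnp.finfo instead of
a hard-coded 1e-12, with no numpy import left behind (patch 12/12). A quick
end-to-end check, reusing the ys_2 data and the 9.2889 expectation pinned
in the test added by patch 02/12:

    import jax.numpy as jnp
    import optax

    embeddings = jnp.array([
        [0.0, 0.0],          # zero vector: formerly a NaN source
        [0.2380, -0.5703],
        [1.8745, -0.0195],
        [-0.6719, -1.9210],
    ])
    labels = jnp.array([0, 0, 1, 1])

    loss = optax.ntxent(embeddings, labels)
    print(loss)                # ~9.2889, matching the test's exp_3
    print(jnp.isfinite(loss))  # True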