diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 634cea1a7..1e077ea12 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Classification losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -138,8 +139,11 @@ def softmax_cross_entropy_with_integer_labels(
   return log_normalizers - label_logits
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def poly_loss_cross_entropy(
-    logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
+    logits: chex.Array,
+    labels: chex.Array,
+    epsilon: float = 2.0
 ) -> chex.Array:
   r"""Computes PolyLoss between logits and labels.
 
@@ -236,7 +240,8 @@ def kl_divergence_with_log_targets(
 
 
 def convex_kl_divergence(
-    log_predictions: chex.Array, targets: chex.Array
+    log_predictions: chex.Array,
+    targets: chex.Array
 ) -> chex.Array:
   """Computes a convex version of the Kullback-Leibler divergence loss.
 
@@ -262,6 +267,7 @@ def convex_kl_divergence(
   )
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss_with_forward_probs(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -392,6 +398,7 @@ def loop_body(prev, x):
   return per_seq_loss, logalpha_phi, logalpha_emit
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -432,6 +439,7 @@ def ctc_loss(
   return per_seq_loss
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def sigmoid_focal_loss(
     logits: chex.Array,
     labels: chex.Array,
diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index b1d0d53b9..db1af2d3d 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -144,7 +144,7 @@ def setUp(self):
   def test_scalar(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.logits, self.labels, eps
+            self.logits, self.labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
@@ -163,7 +163,7 @@ def test_scalar(self, eps, expected):
   def test_batched(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.batched_logits, self.batched_labels, eps
+            self.batched_logits, self.batched_labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
diff --git a/optax/losses/_regression.py b/optax/losses/_regression.py
index f0f4c9c07..f39beb1dd 100644
--- a/optax/losses/_regression.py
+++ b/optax/losses/_regression.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Regression losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -76,6 +77,7 @@ def l2_loss(
   return 0.5 * squared_error(predictions, targets)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def huber_loss(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
@@ -135,6 +137,7 @@ def log_cosh(
   return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_similarity(
     predictions: chex.Array,
     targets: chex.Array,
@@ -170,6 +173,7 @@ def cosine_similarity(
   return jnp.sum(unit_targets * unit_predictions, axis=-1)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_distance(
     predictions: chex.Array,
     targets: chex.Array,
@@ -193,4 +197,4 @@ def cosine_distance(
   """
   chex.assert_type([predictions, targets], float)
   # cosine distance = 1 - cosine similarity.
-  return 1. - cosine_similarity(predictions, targets, epsilon)
+  return 1. - cosine_similarity(predictions, targets, epsilon=epsilon)
diff --git a/pyproject.toml b/pyproject.toml
index 126a6caf4..67b97acc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
 ]
 dependencies = [
   "absl-py>=0.7.1",
-  "chex>=0.1.7",
+  "chex>=0.1.86",
   "jax>=0.1.55",
   "jaxlib>=0.1.37",
   "numpy>=1.18.0",
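For reference, a minimal sketch of the call-site behaviour this change is meant to introduce, assuming chex>=0.1.86 where `chex.warn_only_n_pos_args_in_future` emits a warning when a wrapped function receives more than `n` positional arguments. The example values and the exact warning text are illustrative, not taken from the patch; only the keyword spelling (`epsilon=...`) is the form the updated tests rely on.

```python
# Sketch only: shows why the tests above switch from positional `eps` to
# `epsilon=eps` once the decorator is applied to poly_loss_cross_entropy.
import warnings

import jax.numpy as jnp

from optax.losses import _classification

logits = jnp.array([[1.0, 2.0, 0.5]])
labels = jnp.array([[0.0, 1.0, 0.0]])

# Recommended spelling going forward: epsilon passed by keyword, no warning.
loss = _classification.poly_loss_cross_entropy(logits, labels, epsilon=2.0)

# Passing epsilon as a third positional argument (beyond n=2) still computes
# the same value, but is expected to trigger a deprecation-style warning.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  _classification.poly_loss_cross_entropy(logits, labels, 2.0)
print(bool(caught))  # Expected: True once the decorator is in place.
```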