Warn that, in the future, arguments after the initial (prediction, ground_truth) positional arguments will become keyword-only in optax losses.

This is to prevent users from accidentally passing arguments in the wrong order, and to make interfaces functionally more similar (e.g. after partialling loss-specific hyper-parameters).

PiperOrigin-RevId: 615023064
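
As an illustration of the intended interface (a sketch, not part of this commit; it assumes optax exports huber_loss and cosine_distance at the top level, which take delta and epsilon hyper-parameters respectively):

import functools
import jax.numpy as jnp
import optax

# Once loss-specific hyper-parameters are keyword-only, partialled losses
# all share the same (predictions, targets) positional call shape.
huber = functools.partial(optax.huber_loss, delta=2.0)
cosine = functools.partial(optax.cosine_distance, epsilon=1e-8)

predictions = jnp.array([[0.3, 0.7], [0.6, 0.4]])
targets = jnp.array([[0.0, 1.0], [1.0, 0.0]])

for loss_fn in (huber, cosine):
  print(loss_fn(predictions, targets))  # identical call shape for both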
mtthss authored and OptaxDev committed Mar 19, 2024
1 parent 68b8d79 commit adff4d4
Showing 4 changed files with 18 additions and 6 deletions.
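
The underlying mechanism is chex.warn_only_n_pos_args_in_future, available from chex 0.1.86 (the reason for the dependency bump below). A minimal sketch of its behavior, using a toy function rather than a real optax loss; the exact warning text is illustrative:

import functools
import chex

@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def toy_loss(predictions, targets, scale=1.0):
  return scale * (predictions - targets) ** 2

toy_loss(1.0, 2.0, scale=3.0)  # fine: only two positional arguments
toy_loss(1.0, 2.0, 3.0)  # still computes, but warns that positional
                         # arguments beyond the first two will become
                         # keyword-only in the future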
12 changes: 10 additions & 2 deletions optax/losses/_classification.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Classification losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -138,8 +139,11 @@ def softmax_cross_entropy_with_integer_labels(
   return log_normalizers - label_logits
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def poly_loss_cross_entropy(
-    logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
+    logits: chex.Array,
+    labels: chex.Array,
+    epsilon: float = 2.0
 ) -> chex.Array:
   r"""Computes PolyLoss between logits and labels.
@@ -236,7 +240,8 @@ def kl_divergence_with_log_targets(
 
 
 def convex_kl_divergence(
-    log_predictions: chex.Array, targets: chex.Array
+    log_predictions: chex.Array,
+    targets: chex.Array
 ) -> chex.Array:
   """Computes a convex version of the Kullback-Leibler divergence loss.
@@ -262,6 +267,7 @@ def convex_kl_divergence(
   )
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss_with_forward_probs(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -392,6 +398,7 @@ def loop_body(prev, x):
   return per_seq_loss, logalpha_phi, logalpha_emit
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -432,6 +439,7 @@ def ctc_loss(
   return per_seq_loss
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def sigmoid_focal_loss(
     logits: chex.Array,
     labels: chex.Array,
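
At a call site, assuming sigmoid_focal_loss is exported at the top level of optax with the (logits, labels, alpha, gamma) signature shown above, the effect looks like this (sketch):

import jax.numpy as jnp
import optax

logits = jnp.array([[-1.0, 2.0]])
labels = jnp.array([[0.0, 1.0]])

# Positional hyper-parameters: still accepted for now, but trigger the warning.
optax.sigmoid_focal_loss(logits, labels, 0.25, 2.0)
# Keyword form: the interface this change steers users toward.
optax.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)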
4 changes: 2 additions & 2 deletions optax/losses/_classification_test.py
@@ -144,7 +144,7 @@ def setUp(self):
   def test_scalar(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.logits, self.labels, eps
+            self.logits, self.labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
@@ -163,7 +163,7 @@ def test_scalar(self, eps, expected):
   def test_batched(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.batched_logits, self.batched_labels, eps
+            self.batched_logits, self.batched_labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
6 changes: 5 additions & 1 deletion optax/losses/_regression.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Regression losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -76,6 +77,7 @@ def l2_loss(
   return 0.5 * squared_error(predictions, targets)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def huber_loss(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
@@ -135,6 +137,7 @@ def log_cosh(
   return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_similarity(
     predictions: chex.Array,
     targets: chex.Array,
@@ -170,6 +173,7 @@ def cosine_similarity(
   return jnp.sum(unit_targets * unit_predictions, axis=-1)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_distance(
     predictions: chex.Array,
     targets: chex.Array,
@@ -193,4 +197,4 @@ def cosine_distance(
   """
   chex.assert_type([predictions, targets], float)
   # cosine distance = 1 - cosine similarity.
-  return 1. - cosine_similarity(predictions, targets, epsilon)
+  return 1. - cosine_similarity(predictions, targets, epsilon=epsilon)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
 ]
 dependencies = [
     "absl-py>=0.7.1",
-    "chex>=0.1.7",
+    "chex>=0.1.86",
     "jax>=0.1.55",
     "jaxlib>=0.1.37",
     "numpy>=1.18.0",
