Warn that, in the future, arguments after the initial (prediction, ground_truth) positional arguments will become keyword-only in optax losses. #863

Merged: 1 commit, Mar 19, 2024
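In short, each decorated loss keeps accepting its trailing arguments positionally for now, but warns that they will become keyword-only; the decorator comes from chex, which is why the chex minimum below rises to 0.1.86. A minimal sketch of the user-facing effect, assuming an optax release that includes this change:

```python
import warnings

import jax.numpy as jnp
from optax import losses

logits = jnp.array([[2.0, -1.0, 0.5]])
labels = jnp.array([[1.0, 0.0, 0.0]])  # one-hot targets

# Passing epsilon positionally (a third positional argument) still works,
# but now emits a warning that it will become keyword-only.
with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  losses.poly_loss_cross_entropy(logits, labels, 2.0)
assert caught  # at least one warning was recorded

# The keyword spelling is the supported form going forward and stays silent.
losses.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
```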
12 changes: 10 additions & 2 deletions optax/losses/_classification.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Classification losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -138,8 +139,11 @@ def softmax_cross_entropy_with_integer_labels(
   return log_normalizers - label_logits
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def poly_loss_cross_entropy(
-    logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
+    logits: chex.Array,
+    labels: chex.Array,
+    epsilon: float = 2.0
 ) -> chex.Array:
   r"""Computes PolyLoss between logits and labels.
@@ -236,7 +240,8 @@ def kl_divergence_with_log_targets(
 
 
 def convex_kl_divergence(
-    log_predictions: chex.Array, targets: chex.Array
+    log_predictions: chex.Array,
+    targets: chex.Array
 ) -> chex.Array:
   """Computes a convex version of the Kullback-Leibler divergence loss.
@@ -262,6 +267,7 @@ convex_kl_divergence(
   )
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss_with_forward_probs(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -392,6 +398,7 @@ def loop_body(prev, x):
   return per_seq_loss, logalpha_phi, logalpha_emit
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
 def ctc_loss(
     logits: chex.Array,
     logit_paddings: chex.Array,
@@ -432,6 +439,7 @@ def ctc_loss(
   return per_seq_loss
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def sigmoid_focal_loss(
     logits: chex.Array,
     labels: chex.Array,
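For reference, chex's warn_only_n_pos_args_in_future wraps a function and warns whenever more than n arguments are passed positionally. A simplified stand-in that illustrates the mechanics; the actual chex implementation will differ in wording and details:

```python
import functools
import warnings


def warn_only_n_pos_args_in_future(fn, n):
  """Simplified stand-in for chex.warn_only_n_pos_args_in_future."""

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if len(args) > n:
      warnings.warn(
          f'Only the first {n} arguments of {fn.__name__} may be positional; '
          'the rest will become keyword-only in a future release.',
          DeprecationWarning,
      )
    return fn(*args, **kwargs)

  return wrapper
```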
4 changes: 2 additions & 2 deletions optax/losses/_classification_test.py
@@ -144,7 +144,7 @@ def setUp(self):
   def test_scalar(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.logits, self.labels, eps
+            self.logits, self.labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
@@ -163,7 +163,7 @@ def test_scalar(self, eps, expected):
   def test_batched(self, eps, expected):
     np.testing.assert_allclose(
         self.variant(_classification.poly_loss_cross_entropy)(
-            self.batched_logits, self.batched_labels, eps
+            self.batched_logits, self.batched_labels, epsilon=eps
         ),
         expected,
         atol=1e-4,
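The test updates switch to epsilon=eps so the suite no longer passes a third positional argument and trips the new warning. A hypothetical companion test (not part of this PR) could pin the warning behavior down explicitly:

```python
import warnings

from absl.testing import absltest
import jax.numpy as jnp
from optax import losses


class KeywordOnlyWarningTest(absltest.TestCase):

  def test_extra_positional_arg_warns(self):
    logits = jnp.array([[2.0, -1.0, 0.5]])
    labels = jnp.array([[1.0, 0.0, 0.0]])
    # A third positional argument should now be reported to the caller.
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter('always')
      losses.poly_loss_cross_entropy(logits, labels, 2.0)
    self.assertNotEmpty(caught)


if __name__ == '__main__':
  absltest.main()
```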
6 changes: 5 additions & 1 deletion optax/losses/_regression.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Regression losses."""
 
+import functools
 from typing import Optional
 
 import chex
@@ -76,6 +77,7 @@ def l2_loss(
   return 0.5 * squared_error(predictions, targets)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def huber_loss(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
@@ -135,6 +137,7 @@ def log_cosh(
   return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_similarity(
     predictions: chex.Array,
     targets: chex.Array,
@@ -170,6 +173,7 @@ def cosine_similarity(
   return jnp.sum(unit_targets * unit_predictions, axis=-1)
 
 
+@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def cosine_distance(
     predictions: chex.Array,
     targets: chex.Array,
@@ -193,4 +197,4 @@ def cosine_distance(
   """
   chex.assert_type([predictions, targets], float)
   # cosine distance = 1 - cosine similarity.
-  return 1. - cosine_similarity(predictions, targets, epsilon)
+  return 1. - cosine_similarity(predictions, targets, epsilon=epsilon)
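Note the final hunk: cosine_distance itself called cosine_similarity with epsilon positional, so once cosine_similarity is decorated, every cosine_distance call would have emitted the warning; passing epsilon=epsilon keeps the internal call silent. Downstream callers should migrate the same way, e.g.:

```python
import jax.numpy as jnp
from optax import losses

preds = jnp.array([[0.5, 0.5], [1.0, 0.0]])
targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])

# Old spelling, now warns: losses.cosine_distance(preds, targets, 1e-8)
# New spelling, silent:
d = losses.cosine_distance(preds, targets, epsilon=1e-8)
```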
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
 ]
 dependencies = [
     "absl-py>=0.7.1",
-    "chex>=0.1.7",
+    "chex>=0.1.86",
     "jax>=0.1.55",
     "jaxlib>=0.1.37",
     "numpy>=1.18.0",