Skip to content

Commit

Permalink
Arguments after the initial (prediction, ground_truth) positional arg…
Browse files Browse the repository at this point in the history
…uments should be keyword-only in optax losses.

This is to prevent users from accidentally passing in the wrong order of arguments, and to make interfaces functionally more similar (e.g. after partialling loss-specific hyper-parameters).

PiperOrigin-RevId: 615023064
  • Loading branch information
mtthss authored and OptaxDev committed Mar 13, 2024
1 parent 269c9dc commit 7095640
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
13 changes: 11 additions & 2 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# ==============================================================================
"""Classification losses."""

import functools
from typing import Optional

import chex
import jax
import jax.numpy as jnp


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def sigmoid_binary_cross_entropy(
logits,
labels,
Expand Down Expand Up @@ -139,7 +141,10 @@ def softmax_cross_entropy_with_integer_labels(


def poly_loss_cross_entropy(
logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
logits: chex.Array,
labels: chex.Array,
*,
epsilon: float = 2.0
) -> chex.Array:
r"""Computes PolyLoss between logits and labels.
Expand Down Expand Up @@ -236,7 +241,8 @@ def kl_divergence_with_log_targets(


def convex_kl_divergence(
log_predictions: chex.Array, targets: chex.Array
log_predictions: chex.Array,
targets: chex.Array
) -> chex.Array:
"""Computes a convex version of the Kullback-Leibler divergence loss.
Expand Down Expand Up @@ -267,6 +273,7 @@ def ctc_loss_with_forward_probs(
logit_paddings: chex.Array,
labels: chex.Array,
label_paddings: chex.Array,
*,
blank_id: int = 0,
log_epsilon: float = -1e5
) -> tuple[chex.Array, chex.Array, chex.Array]:
Expand Down Expand Up @@ -397,6 +404,7 @@ def ctc_loss(
logit_paddings: chex.Array,
labels: chex.Array,
label_paddings: chex.Array,
*,
blank_id: int = 0,
log_epsilon: float = -1e5
) -> chex.Array:
Expand Down Expand Up @@ -435,6 +443,7 @@ def ctc_loss(
def sigmoid_focal_loss(
logits: chex.Array,
labels: chex.Array,
*,
alpha: Optional[float] = None,
gamma: float = 2.,
) -> chex.Array:
Expand Down
4 changes: 2 additions & 2 deletions optax/losses/_classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def setUp(self):
def test_scalar(self, eps, expected):
np.testing.assert_allclose(
self.variant(_classification.poly_loss_cross_entropy)(
self.logits, self.labels, eps
self.logits, self.labels, epsilon=eps
),
expected,
atol=1e-4,
Expand All @@ -163,7 +163,7 @@ def test_scalar(self, eps, expected):
def test_batched(self, eps, expected):
np.testing.assert_allclose(
self.variant(_classification.poly_loss_cross_entropy)(
self.batched_logits, self.batched_labels, eps
self.batched_logits, self.batched_labels, epsilon=eps
),
expected,
atol=1e-4,
Expand Down
5 changes: 4 additions & 1 deletion optax/losses/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def l2_loss(
def huber_loss(
predictions: chex.Array,
targets: Optional[chex.Array] = None,
*,
delta: float = 1.
) -> chex.Array:
"""Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
Expand Down Expand Up @@ -138,6 +139,7 @@ def log_cosh(
def cosine_similarity(
predictions: chex.Array,
targets: chex.Array,
*,
epsilon: float = 0.,
) -> chex.Array:
r"""Computes the cosine similarity between targets and predictions.
Expand Down Expand Up @@ -173,6 +175,7 @@ def cosine_similarity(
def cosine_distance(
predictions: chex.Array,
targets: chex.Array,
*,
epsilon: float = 0.,
) -> chex.Array:
r"""Computes the cosine distance between targets and predictions.
Expand All @@ -193,4 +196,4 @@ def cosine_distance(
"""
chex.assert_type([predictions, targets], float)
# cosine distance = 1 - cosine similarity.
return 1. - cosine_similarity(predictions, targets, epsilon)
return 1. - cosine_similarity(predictions, targets, epsilon=epsilon)

0 comments on commit 7095640

Please sign in to comment.