
Merge pull request #912 from carlosgmartin:loss_axis_where
PiperOrigin-RevId: 668475651
OptaxDev committed Aug 28, 2024
2 parents 40b3a71 + 2f1aec0 commit eba3733
Showing 5 changed files with 290 additions and 34 deletions.
20 changes: 17 additions & 3 deletions docs/api/losses.rst
@@ -12,13 +12,17 @@ Losses
hinge_loss
huber_loss
kl_divergence
kl_divergence_with_log_targets
l2_loss
log_cosh
make_fenchel_young_loss
multiclass_hinge_loss
multiclass_perceptron_loss
multiclass_sparsemax_loss
ntxent
perceptron_loss
poly_loss_cross_entropy
ranking_softmax_loss
safe_softmax_cross_entropy
sigmoid_binary_cross_entropy
sigmoid_focal_loss
@@ -46,6 +50,10 @@ Connectionist temporal classification loss
.. autofunction:: ctc_loss
.. autofunction:: ctc_loss_with_forward_probs

Fenchel Young loss
~~~~~~~~~~~~~~~~~~
.. autofunction:: make_fenchel_young_loss

Hinge loss
~~~~~~~~~~
.. autofunction:: hinge_loss
@@ -58,6 +66,7 @@ Huber loss
Kullback-Leibler divergence
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: kl_divergence
.. autofunction:: kl_divergence_with_log_targets

L2 Squared loss
~~~~~~~~~~~~~~~
@@ -72,11 +81,19 @@ Normalized temperature scaled cross-entropy (NT-Xent) loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ntxent

Poly loss cross-entropy
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: poly_loss_cross_entropy

Perceptron
~~~~~~~~~~~
.. autofunction:: perceptron_loss
.. autofunction:: multiclass_perceptron_loss

Ranking softmax loss
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: ranking_softmax_loss

Sigmoid binary cross-entropy
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sigmoid_binary_cross_entropy
@@ -99,6 +116,3 @@ Sparsemax
~~~~~~~~~
.. autofunction:: sparsemax_loss
.. autofunction:: multiclass_sparsemax_loss



82 changes: 66 additions & 16 deletions optax/losses/_classification.py
@@ -15,7 +15,7 @@
"""Classification losses."""

import functools
from typing import Optional
from typing import Optional, Union

import chex
import jax
@@ -209,6 +209,8 @@ def safe_softmax_cross_entropy(
def softmax_cross_entropy(
logits: chex.Array,
labels: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes the softmax cross entropy between sets of logits and labels.
@@ -249,6 +251,8 @@ def softmax_cross_entropy(
num_classes]``.
labels: One-hot encoded labels, with shape `[batch_size, num_classes]`. Each
row represents the true class distribution for a single example.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Cross-entropy between each prediction and the corresponding target
@@ -261,14 +265,20 @@ def softmax_cross_entropy(
:func:`optax.losses.safe_softmax_cross_entropy` provides an alternative
implementation that differs on how ``logits=-inf`` are handled.
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
chex.assert_type([logits], float)
return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
log_probs = jax.nn.log_softmax(logits, axis, where)
return -(labels * log_probs).sum(axis, where=where)


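# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). With the new
# ``axis`` and ``where`` arguments introduced above (optax >= 0.2.4), a padded
# class can be excluded from both the softmax normalization and the reduction.
# The values below are made up.
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 1.0, -1.0, 0.0]])    # last column is padding
labels = jnp.array([[1.0, 0.0, 0.0, 0.0]])     # one-hot targets
mask = jnp.array([[True, True, True, False]])  # exclude the padded class

masked_loss = optax.losses.softmax_cross_entropy(logits, labels, where=mask)
# ---------------------------------------------------------------------------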
def softmax_cross_entropy_with_integer_labels(
logits: chex.Array,
labels: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes softmax cross entropy between the logits and integer labels.
@@ -307,6 +317,8 @@ def softmax_cross_entropy_with_integer_labels(
labels: Integers specifying the correct class for each input, with shape
``[batch_size]``. Class labels are assumed to be between 0 and
``num_classes - 1`` inclusive.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Cross-entropy between each prediction and the corresponding target
@@ -315,16 +327,23 @@ def softmax_cross_entropy_with_integer_labels(
.. seealso:: This function is similar to
:func:`optax.losses.softmax_cross_entropy`, but accepts integer labels
instead of one-hot labels.
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
chex.assert_type([logits], float)
chex.assert_type([labels], int)
# This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that
# we avoid subtracting the normalizer from all values, just from the values
# for the correct labels.
logits_max = jnp.max(logits, axis=-1, keepdims=True)
logits_max = jnp.max(
logits, axis, keepdims=True, where=where, initial=-jnp.inf
)
logits -= jax.lax.stop_gradient(logits_max)
label_logits = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1))
label_logits = jnp.take_along_axis(
logits, jnp.expand_dims(labels, axis), axis=axis
).take(0, axis=axis)
log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=axis, where=where))
return log_normalizers - label_logits


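# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). With the new
# ``axis`` argument (optax >= 0.2.4) the class dimension no longer has to be
# the trailing one, e.g. for channel-first sequence logits. Shapes are made up.
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (8, 10, 32))  # [batch, num_classes, time]
labels = jnp.zeros((8, 32), dtype=jnp.int32)  # class id per (batch, time)

per_position_loss = optax.losses.softmax_cross_entropy_with_integer_labels(
    logits, labels, axis=1
)
assert per_position_loss.shape == (8, 32)
# ---------------------------------------------------------------------------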
@@ -388,7 +407,9 @@ def multiclass_perceptron_loss(
def poly_loss_cross_entropy(
logits: chex.Array,
labels: chex.Array,
epsilon: float = 2.0
epsilon: float = 2.0,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
r"""Computes PolyLoss between logits and labels.
@@ -421,20 +442,30 @@ def poly_loss_cross_entropy(
- For the 2d Instance Segmentation and object detection, epsilon = -1.0.
- It is also recommended to adjust this value based on the task, e.g. by
using grid search.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Poly loss between each prediction and the corresponding target
distributions, with shape `[...]`.
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
chex.assert_type([logits, labels], float)
one_minus_pt = jnp.sum(labels * (1 - jax.nn.softmax(logits)), axis=-1)
cross_entropy = softmax_cross_entropy(logits=logits, labels=labels)
p = jax.nn.softmax(logits, axis=axis, where=where)
one_minus_pt = jnp.sum(labels * (1 - p), axis=axis, where=where)
cross_entropy = softmax_cross_entropy(
logits=logits, labels=labels, axis=axis, where=where
)
return cross_entropy + epsilon * one_minus_pt


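# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). The poly
# loss is plain softmax cross entropy plus ``epsilon * (1 - p_t)``, so with
# ``epsilon=0`` the two calls below coincide. Values are made up.
import jax.numpy as jnp
import optax

logits = jnp.array([[0.5, -1.0, 2.0]])
labels = jnp.array([[0.0, 0.0, 1.0]])

poly = optax.losses.poly_loss_cross_entropy(logits, labels, epsilon=0.0)
xent = optax.losses.softmax_cross_entropy(logits, labels)
assert jnp.allclose(poly, xent)
# ---------------------------------------------------------------------------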
def kl_divergence(
log_predictions: chex.Array,
targets: chex.Array
targets: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
"""Computes the Kullback-Leibler divergence (relative entropy) loss.
@@ -449,21 +480,28 @@ def kl_divergence(
dim]. Expected to be in the log-space to avoid underflow.
targets: Probabilities of target distribution with shape [..., dim].
Expected to be strictly positive.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Kullback-Leibler divergence of predicted distribution from target
distribution with shape [...].
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
chex.assert_type([log_predictions, targets], float)
loss = targets * (
jnp.where(targets == 0, 0, jnp.log(targets)) - log_predictions
)
return jnp.sum(loss, axis=-1)
return jnp.sum(loss, axis=axis, where=where)


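# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). The new
# ``where`` argument restricts the KL sum to part of the support; masked
# entries contribute nothing. Distributions below are made up.
import jax.numpy as jnp
import optax

log_predictions = jnp.log(jnp.array([[0.2, 0.3, 0.5]]))
targets = jnp.array([[0.25, 0.25, 0.5]])
mask = jnp.array([[True, True, False]])

full_kl = optax.losses.kl_divergence(log_predictions, targets)
partial_kl = optax.losses.kl_divergence(log_predictions, targets, where=mask)
# ---------------------------------------------------------------------------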
def kl_divergence_with_log_targets(
log_predictions: chex.Array,
log_targets: chex.Array
log_targets: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
"""Computes the Kullback-Leibler divergence (relative entropy) loss.
@@ -474,19 +512,26 @@ def kl_divergence_with_log_targets(
[..., dim]. Expected to be in the log-space to avoid underflow.
log_targets: Probabilities of target distribution with shape [..., dim].
Expected to be in the log-space.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Kullback-Leibler divergence of predicted distribution from target
distribution with shape [...].
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
chex.assert_type([log_predictions, log_targets], float)
loss = jnp.exp(log_targets) * (log_targets - log_predictions)
return jnp.sum(loss, axis=-1)
return jnp.sum(loss, axis=axis, where=where)


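# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). For strictly
# positive targets, the log-target variant agrees with ``kl_divergence``
# applied to ``exp(log_targets)``. Distributions below are made up.
import jax.numpy as jnp
import optax

log_predictions = jnp.log(jnp.array([[0.2, 0.3, 0.5]]))
log_targets = jnp.log(jnp.array([[0.25, 0.25, 0.5]]))

a = optax.losses.kl_divergence_with_log_targets(log_predictions, log_targets)
b = optax.losses.kl_divergence(log_predictions, jnp.exp(log_targets))
# a and b agree up to floating point error.
# ---------------------------------------------------------------------------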
def convex_kl_divergence(
log_predictions: chex.Array,
targets: chex.Array
targets: chex.Array,
axis: Union[int, tuple[int, ...], None] = -1,
where: Union[chex.Array, None] = None,
) -> chex.Array:
"""Computes a convex version of the Kullback-Leibler divergence loss.
@@ -502,14 +547,19 @@ def convex_kl_divergence(
dim]. Expected to be in the log-space to avoid underflow.
targets: Probabilities of target distribution with shape [..., dim].
Expected to be strictly positive.
axis: Axis or axes along which to compute.
where: Elements to include in the computation.
Returns:
Kullback-Leibler divergence of predicted distribution from target
distribution with shape [...].
.. versionchanged:: 0.2.4
Added ``axis`` and ``where`` arguments.
"""
return kl_divergence(log_predictions, targets) + jnp.sum(
jnp.exp(log_predictions) - targets, axis=-1
)
x = kl_divergence(log_predictions, targets, axis=axis, where=where)
y = jnp.sum(jnp.exp(log_predictions) - targets, axis=axis, where=where)
return x + y


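# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of this commit). When both
# distributions are normalized over the reduced axis, the correction term
# ``sum(exp(log_predictions) - targets)`` vanishes and the convex variant
# matches plain ``kl_divergence``. Distributions below are made up.
import jax.numpy as jnp
import optax

log_predictions = jnp.log(jnp.array([[0.2, 0.3, 0.5]]))
targets = jnp.array([[0.25, 0.25, 0.5]])

convex = optax.losses.convex_kl_divergence(log_predictions, targets)
plain = optax.losses.kl_divergence(log_predictions, targets)
# convex and plain agree up to floating point error here.
# ---------------------------------------------------------------------------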
@functools.partial(chex.warn_only_n_pos_args_in_future, n=4)
(Diffs for the remaining three changed files are not shown.)
