From 4f2d436e6777707d710522d6dc664e86092cefb3 Mon Sep 17 00:00:00 2001
From: Matteo Hessel
Date: Fri, 15 Mar 2024 07:06:31 -0700
Subject: [PATCH] Use optax losses as backend for jaxopt losses.

PiperOrigin-RevId: 616117244
---
 jaxopt/_src/loss.py | 92 +++++++--------------------------------------
 jaxopt/loss.py      |  6 ++-
 requirements.txt    |  1 +
 3 files changed, 20 insertions(+), 79 deletions(-)

diff --git a/jaxopt/_src/loss.py b/jaxopt/_src/loss.py
index c48aca96..1574b24a 100644
--- a/jaxopt/_src/loss.py
+++ b/jaxopt/_src/loss.py
@@ -17,10 +17,10 @@
 from typing import Callable
 
 import jax
-from jax.nn import softplus
 import jax.numpy as jnp
-from jax.scipy.special import logsumexp
-from jaxopt._src.projection import projection_simplex, projection_hypercube
+from jaxopt._src.projection import projection_simplex
+
+from optax import losses as optax_losses
 
 
 # Regression
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
   References:
     https://en.wikipedia.org/wiki/Huber_loss
   """
-  abs_diff = jnp.abs(target - pred)
-  return jnp.where(abs_diff > delta,
-                   delta * (abs_diff - .5 * delta),
-                   0.5 * abs_diff ** 2)
+  return optax_losses.huber_loss(pred, target, delta)
 
 
 # Binary classification.
@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
   Returns:
     loss value
   """
-  # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
-  # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
-  # where xlogx(proba) = proba * log(proba).
-  # Use -log sigmoid(logit) = softplus(-logit)
-  # and 1 - sigmoid(logit) = sigmoid(-logit).
-  return softplus(jnp.where(label, -logit, logit))
+  return optax_losses.sigmoid_binary_cross_entropy(
+      jnp.asarray(logit), jnp.asarray(label))
 
 
 def binary_sparsemax_loss(label: int, logit: float) -> float:
@@ -77,59 +70,8 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
-  return sparse_plus(jnp.where(label, -logit, logit))
-
-
-def sparse_plus(x: float) -> float:
-  r"""Sparse plus function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_plus}(x) = \begin{cases}
-    0, & x \leq -1\\
-    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
-    x, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the softplus activation ensuring a zero output
-  for inputs less than -1 and a linear output for inputs greater than 1,
-  while remaining smooth, convex, monotonic by an adequate definition between
-  -1 and 1.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_plus(x) as defined above
-  """
-  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
-
-
-def sparse_sigmoid(x: float) -> float:
-  r"""Sparse sigmoid function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
-    0, & x \leq -1\\
-    \frac{1}{2}(x+1), & -1 < x < 1 \\
-    1, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the sigmoid activation ensuring a zero output
-  for inputs less than -1, a 1 ouput for inputs greater than 1, and a linear
-  output for inputs between -1 and 1. This is the derivative of the sparse
-  plus function.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_sigmoid(x) as defined above
-  """
-  return 0.5 * projection_hypercube(x + 1.0, 2.0)
+  return optax_losses.sparsemax_loss(
+      jnp.asarray(logit), jnp.asarray(label))
 
 
 def binary_hinge_loss(label: int, score: float) -> float:
@@ -144,8 +86,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Hinge_loss
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, 1 - score * signed_label)
+  return optax_losses.hinge_loss(score, 2.0 * label - 1.0)
 
 
 def binary_perceptron_loss(label: int, score: float) -> float:
@@ -160,8 +101,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Perceptron
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, - score * signed_label)
+  return optax_losses.perceptron_loss(score, 2.0 * label - 1.0)
 
 
 # Multiclass classification.
@@ -175,13 +115,8 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
   Returns:
     loss value
   """
-  logits = jnp.asarray(logits)
-  # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
-  # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
-  # To avoid roundoff error, subtract target inside logsumexp.
-  # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
-  logits = (logits - logits[label]).at[label].set(0.0)
-  return logsumexp(logits)
+  return optax_losses.softmax_cross_entropy_with_integer_labels(
+      jnp.asarray(logits), jnp.asarray(label))
 
 
 def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
@@ -272,5 +207,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
   """
   def fy_loss(y_true, scores, *args, **kwargs):
-    return max_fun(scores, *args, **kwargs) - jnp.vdot(y_true, scores)
+    return optax_losses.make_fenchel_young_loss(max_fun)(
+        scores.ravel(), y_true.ravel(), *args, **kwargs)
 
   return fy_loss
diff --git a/jaxopt/loss.py b/jaxopt/loss.py
index 04189be7..29ab04d2 100644
--- a/jaxopt/loss.py
+++ b/jaxopt/loss.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import jax
+sparse_plus = jax.nn.sparse_plus
+sparse_sigmoid = jax.nn.sparse_sigmoid
+
 from jaxopt._src.loss import binary_logistic_loss
-from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid
+from jaxopt._src.loss import binary_sparsemax_loss
 from jaxopt._src.loss import huber_loss
 from jaxopt._src.loss import make_fenchel_young_loss
 from jaxopt._src.loss import multiclass_logistic_loss
diff --git a/requirements.txt b/requirements.txt
index 85417b2c..3a319bb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 jax>=0.2.18
 jaxlib>=0.1.69
 numpy>=1.18.4
+optax>=0.2.2
 scipy>=1.0.0
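
A minimal sanity-check sketch (illustrative only, not part of the patch), assuming optax>=0.2.2 and the patched jaxopt are installed: it verifies that the optax-backed wrappers keep jaxopt's original (label, score) argument order, while the optax losses themselves take (logits/predictions, labels).

    # Illustrative check of the delegation; assumes optax>=0.2.2 and the
    # patched jaxopt are importable. Not part of the change itself.
    import jax.numpy as jnp
    from optax import losses as optax_losses
    from jaxopt import loss as jaxopt_loss

    # jaxopt wrappers keep (label, logit); optax backends take (logits, labels).
    label, logit = 1, 2.5
    assert jnp.allclose(
        jaxopt_loss.binary_logistic_loss(label, logit),
        optax_losses.sigmoid_binary_cross_entropy(
            jnp.asarray(logit), jnp.asarray(label)))

    # Same pattern for the multiclass case.
    labels, logits = 2, jnp.array([1.0, 0.5, -0.5])
    assert jnp.allclose(
        jaxopt_loss.multiclass_logistic_loss(labels, logits),
        optax_losses.softmax_cross_entropy_with_integer_labels(
            logits, jnp.asarray(labels)))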