Use optax losses as backend for jaxopt losses. #586

Open · wants to merge 1 commit into main
92 changes: 14 additions & 78 deletions jaxopt/_src/loss.py
@@ -17,10 +17,10 @@
 from typing import Callable

 import jax
-from jax.nn import softplus
 import jax.numpy as jnp
-from jax.scipy.special import logsumexp
-from jaxopt._src.projection import projection_simplex, projection_hypercube
+from jaxopt._src.projection import projection_simplex
+
+from optax import losses as optax_losses


 # Regression
@@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
   References:
     https://en.wikipedia.org/wiki/Huber_loss
   """
-  abs_diff = jnp.abs(target - pred)
-  return jnp.where(abs_diff > delta,
-                   delta * (abs_diff - .5 * delta),
-                   0.5 * abs_diff ** 2)
+  return optax_losses.huber_loss(pred, target, delta)

 # Binary classification.
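Note on argument order: jaxopt's huber_loss signature is (target, pred, delta) while optax.losses.huber_loss takes predictions first, hence the swap in the call above. A minimal equivalence check against the removed formula (a sketch, assuming optax is installed as per requirements.txt):

```python
import jax.numpy as jnp
from optax import losses as optax_losses

def reference_huber(target, pred, delta=1.0):
  # The removed jaxopt implementation, kept here only as a reference point.
  abs_diff = jnp.abs(target - pred)
  return jnp.where(abs_diff > delta,
                   delta * (abs_diff - 0.5 * delta),
                   0.5 * abs_diff ** 2)

target, pred = jnp.asarray(1.5), jnp.asarray(-0.3)
print(reference_huber(target, pred))          # 1.3 (|diff| = 1.8 > delta)
print(optax_losses.huber_loss(pred, target))  # expected to match
```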

@@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
   Returns:
     loss value
   """
-  # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
-  # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
-  # where xlogx(proba) = proba * log(proba).
-  # Use -log sigmoid(logit) = softplus(-logit)
-  # and 1 - sigmoid(logit) = sigmoid(-logit).
-  return softplus(jnp.where(label, -logit, logit))
+  return optax_losses.sigmoid_binary_cross_entropy(
+      jnp.asarray(logit), jnp.asarray(label))


 def binary_sparsemax_loss(label: int, logit: float) -> float:
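Throughout this file the jaxopt losses keep their (label, logit) signature while the optax backends take (logit, label); the jnp.asarray calls guard against plain Python ints and floats. For the binary logistic case, optax.losses.sigmoid_binary_cross_entropy reduces to the same softplus expression the removed comment derived. A quick check (a sketch):

```python
import jax.numpy as jnp
from jax.nn import softplus
from optax import losses as optax_losses

logit, label = jnp.asarray(2.0), jnp.asarray(1)
old = softplus(jnp.where(label, -logit, logit))  # removed implementation
new = optax_losses.sigmoid_binary_cross_entropy(logit, label)
print(old, new, jnp.allclose(old, new))          # expected: True
```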
@@ -77,59 +70,8 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
     Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
     Vlad Niculae. JMLR 2020. (Sec. 4.4)
   """
-  return sparse_plus(jnp.where(label, -logit, logit))
-
-
-def sparse_plus(x: float) -> float:
-  r"""Sparse plus function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_plus}(x) = \begin{cases}
-    0, & x \leq -1\\
-    \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
-    x, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the softplus activation ensuring a zero output
-  for inputs less than -1 and a linear output for inputs greater than 1,
-  while remaining smooth, convex, monotonic by an adequate definition between
-  -1 and 1.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_plus(x) as defined above
-  """
-  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
-
-
-def sparse_sigmoid(x: float) -> float:
-  r"""Sparse sigmoid function.
-
-  Computes the function:
-
-  .. math::
-
-    \mathrm{sparse\_sigmoid}(x) = \begin{cases}
-    0, & x \leq -1\\
-    \frac{1}{2}(x+1), & -1 < x < 1 \\
-    1, & 1 \leq x
-    \end{cases}
-
-  This is the twin function of the sigmoid activation ensuring a zero output
-  for inputs less than -1, a 1 ouput for inputs greater than 1, and a linear
-  output for inputs between -1 and 1. This is the derivative of the sparse
-  plus function.
-
-  Args:
-    x: input (float)
-  Returns:
-    sparse_sigmoid(x) as defined above
-  """
-  return 0.5 * projection_hypercube(x + 1.0, 2.0)
+  return optax_losses.sparsemax_loss(
+      jnp.asarray(logit), jnp.asarray(label))


 def binary_hinge_loss(label: int, score: float) -> float:
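Besides delegating binary_sparsemax_loss to optax.losses.sparsemax_loss, this hunk deletes the local sparse_plus and sparse_sigmoid helpers; their replacements come from jax.nn (see the jaxopt/loss.py change below). A hedged consistency check, assuming a JAX version that ships jax.nn.sparse_plus:

```python
import jax.numpy as jnp
from jax.nn import sparse_plus
from optax import losses as optax_losses

logit, label = jnp.asarray(0.4), jnp.asarray(1)
old = sparse_plus(jnp.where(label, -logit, logit))  # removed formulation
new = optax_losses.sparsemax_loss(logit, label)
print(old, new)  # both expected to be 0.09 = (1 - 0.4)**2 / 4
```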
@@ -144,8 +86,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Hinge_loss
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, 1 - score * signed_label)
+  return optax_losses.hinge_loss(score, 2.0 * label - 1.0)


 def binary_perceptron_loss(label: int, score: float) -> float:
@@ -160,8 +101,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
   References:
     https://en.wikipedia.org/wiki/Perceptron
   """
-  signed_label = 2.0 * label - 1.0
-  return jnp.maximum(0, - score * signed_label)
+  return optax_losses.perceptron_loss(score, 2.0 * label - 1.0)

 # Multiclass classification.
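The optax hinge and perceptron backends expect signed targets in {-1, +1}, so the 2.0 * label - 1.0 mapping that used to live inside the jaxopt functions moves into the call sites above. For example (a sketch):

```python
import jax.numpy as jnp
from optax import losses as optax_losses

label, score = 0, jnp.asarray(0.7)
signed_label = 2.0 * label - 1.0                          # {0, 1} -> {-1, +1}
print(optax_losses.hinge_loss(score, signed_label))       # max(0, 1 - 0.7 * -1) = 1.7
print(optax_losses.perceptron_loss(score, signed_label))  # max(0, -0.7 * -1)    = 0.7
```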

@@ -175,13 +115,8 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
   Returns:
     loss value
   """
-  logits = jnp.asarray(logits)
-  # Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
-  # logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
-  # To avoid roundoff error, subtract target inside logsumexp.
-  # logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
-  logits = (logits - logits[label]).at[label].set(0.0)
-  return logsumexp(logits)
+  return optax_losses.softmax_cross_entropy_with_integer_labels(
+      jnp.asarray(logits), jnp.asarray(label))


 def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
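optax.losses.softmax_cross_entropy_with_integer_labels takes raw logits plus an integer class index and performs its own max-shift for numerical stability, so the removed subtract-the-target trick is handled internally. A small check against the removed logsumexp formulation (a sketch):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from optax import losses as optax_losses

logits = jnp.asarray([0.1, -1.3, 2.2, 0.5])
label = 2
old = logsumexp(logits) - logits[label]  # what the removed code computed
new = optax_losses.softmax_cross_entropy_with_integer_labels(
    logits, jnp.asarray(label))
print(jnp.allclose(old, new))            # expected: True
```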
@@ -272,5 +207,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
   """

   def fy_loss(y_true, scores, *args, **kwargs):
-    return max_fun(scores, *args, **kwargs) - jnp.vdot(y_true, scores)
+    return optax_losses.make_fenchel_young_loss(max_fun)(
+        scores.ravel(), y_true.ravel(), *args, **kwargs)
   return fy_loss
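The public jaxopt wrapper keeps its fy_loss(y_true, scores) signature and forwards to the optax backend, which, judging from the call above, takes the flattened scores before the targets. With max_fun = logsumexp the construction recovers the multiclass logistic loss for one-hot targets under jaxopt's original definition F(scores) - <y_true, scores>; the optax backend is expected to agree up to terms constant in the scores. A usage sketch:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jaxopt.loss import make_fenchel_young_loss

# Fenchel-Young loss generated from a max (log-partition style) function.
fy_logistic = make_fenchel_young_loss(max_fun=logsumexp)

scores = jnp.asarray([0.1, -1.3, 2.2, 0.5])
y_true = jnp.asarray([0.0, 0.0, 1.0, 0.0])  # one-hot target for class 2
print(fy_logistic(y_true, scores))          # jaxopt definition: logsumexp(scores) - scores[2]
```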
6 changes: 5 additions & 1 deletion jaxopt/loss.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import jax
+sparse_plus = jax.nn.sparse_plus
+sparse_sigmoid = jax.nn.sparse_sigmoid
+
 from jaxopt._src.loss import binary_logistic_loss
-from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid
+from jaxopt._src.loss import binary_sparsemax_loss
 from jaxopt._src.loss import huber_loss
 from jaxopt._src.loss import make_fenchel_young_loss
 from jaxopt._src.loss import multiclass_logistic_loss
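With the local definitions gone from jaxopt/_src/loss.py, jaxopt.loss now re-exports sparse_plus and sparse_sigmoid as aliases of the JAX primitives. A quick check that the aliases match the removed piecewise definitions, including the derivative relationship stated in the old docstring (a sketch, assuming a JAX version that provides both under jax.nn):

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)   # [-2, -1, 0, 1, 2]
print(jax.nn.sparse_plus(x))     # [0., 0., 0.25, 1., 2.]
print(jax.nn.sparse_sigmoid(x))  # [0., 0., 0.5, 1., 1.]

# sparse_sigmoid is the derivative of sparse_plus.
print(jax.grad(jax.nn.sparse_plus)(0.3), jax.nn.sparse_sigmoid(0.3))  # 0.65 0.65
```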
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
 jax>=0.2.18
 jaxlib>=0.1.69
 numpy>=1.18.4
+optax>=0.2.2
 scipy>=1.0.0