From 46cac2c10f9a012fbe36f76e9ecb65ebdf35299f Mon Sep 17 00:00:00 2001
From: Mathieu Blondel
Date: Mon, 29 Aug 2022 14:47:21 +0200
Subject: [PATCH] Small optim in l2-regularized semi-dual.

---
 jaxopt/_src/projection.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/jaxopt/_src/projection.py b/jaxopt/_src/projection.py
index e9dce7bd..04ab9bf3 100644
--- a/jaxopt/_src/projection.py
+++ b/jaxopt/_src/projection.py
@@ -392,13 +392,18 @@ def projection_box_section(x: jnp.ndarray,
 
 def _max_l2(x, marginal_b, gamma):
   scale = gamma * marginal_b
-  p = projection_simplex(x / scale)
+  x_scale = x / scale
+  p = projection_simplex(x_scale)
+  # From Danskin's theorem, we do not need to backpropagate
+  # through projection_simplex.
+  p = jax.lax.stop_gradient(p)
   return jnp.dot(x, p) - 0.5 * scale * jnp.dot(p, p)
 
 
 def _max_ent(x, marginal_b, gamma):
   return gamma * logsumexp(x / gamma) - gamma * jnp.log(marginal_b)
 
+
 _max_l2_vmap = jax.vmap(_max_l2, in_axes=(1, 0, None))
 _max_l2_grad_vmap = jax.vmap(jax.grad(_max_l2), in_axes=(1, 0, None))
 
@@ -771,4 +776,4 @@ def kl_projection_birkhoff(sim_matrix: jnp.ndarray,
   return kl_projection_transport(sim_matrix=sim_matrix,
                                  marginals=(marginals_a, marginals_b),
                                  make_solver=make_solver,
-                                 use_semi_dual=use_semi_dual)
\ No newline at end of file
+                                 use_semi_dual=use_semi_dual)