Skip to content

Commit

Permalink
Merge pull request #412 from froystig:jit-bisect-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 516917371
  • Loading branch information
JAXopt authors committed Mar 15, 2023
2 parents cb6ed9a + e196ece commit ea8e0f1
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions tests/bisection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@ def _optimality_fun_proj_simplex(tau, x, s):
return jnp.sum(jnp.maximum(x - tau, 0)) - s


def _threshold_proj_simplex(x, s):
def _threshold_proj_simplex(bisect, x, s=1.0):
return bisect.run(None, x, s).params


def _projection_simplex_bisect(bisect, x, s=1.0):
return jnp.maximum(x - _threshold_proj_simplex(bisect, x, s), 0)


def _projection_simplex_bisect_setup(x, s=1.0):
# tau = max(x) => tau >= x_i for all i
# => x_i - tau <= 0 for all i
# => maximum(x_i - tau, 0) = 0 for all i
Expand All @@ -45,33 +53,30 @@ def _threshold_proj_simplex(x, s):
# where tau = tau' - s / len(x)
lower = jax.lax.stop_gradient(jnp.min(x)) - s / len(x)

bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex,
lower=lower, upper=upper, check_bracket=False)
return bisect.run(None, x, s).params


def _projection_simplex_bisect(x, s=1.0):
return jnp.maximum(x - _threshold_proj_simplex(x, s), 0)
return Bisection(optimality_fun=_optimality_fun_proj_simplex,
lower=lower, upper=upper, check_bracket=False)


class BisectionTest(test_util.JaxoptTestCase):

def test_bisect(self):
rng = onp.random.RandomState(0)

jitted_fun = jax.jit(_projection_simplex_bisect)
_projection_simplex_bisect_jitted = jax.jit(
_projection_simplex_bisect, static_argnums=0)

for _ in range(10):
x = jnp.array(rng.randn(50).astype(onp.float32))
bisect = _projection_simplex_bisect_setup(x)
p = projection.projection_simplex(x)
p2 = _projection_simplex_bisect(x)
p3 = jitted_fun(x)
p2 = _projection_simplex_bisect(bisect, x)
p3 = _projection_simplex_bisect_jitted(bisect, x)
self.assertArraysAllClose(p, p2, atol=1e-4)
self.assertArraysAllClose(p, p3, atol=1e-4)

J = jax.jacrev(projection.projection_simplex)(x)
J2 = jax.jacrev(_projection_simplex_bisect)(x)
J3 = jax.jacrev(jitted_fun)(x)
J2 = jax.jacrev(_projection_simplex_bisect, argnums=1)(bisect, x)
J3 = jax.jacrev(_projection_simplex_bisect_jitted, argnums=1)(bisect, x)
self.assertArraysAllClose(J, J2, atol=1e-5)
self.assertArraysAllClose(J, J3, atol=1e-5)

Expand All @@ -96,7 +101,9 @@ def test_bisect_wrong_upper_bracket(self):
def test_grad_of_value_and_grad(self):
# See https://github.com/google/jaxopt/issues/141

bisect = lambda x: _projection_simplex_bisect(x)[0]
def bisect(x):
b = _projection_simplex_bisect_setup(x)
return _projection_simplex_bisect(b, x)[0]

def bisect_val(x):
val, _ = jax.value_and_grad(bisect)(x)
Expand Down

0 comments on commit ea8e0f1

Please sign in to comment.