Skip to content

Commit

Permalink
avoid closing over dynamic jax tracers in the bisection solver
Browse files Browse the repository at this point in the history
Internally in jaxopt, we (should) try to maintain that the parameters
to a solver class are "static" from jax's point of view.

One reason for this is that class attributes might be read by any of
the class' methods, including `run`. Meanwhile a bound `run` method
serves as the solver function, which is passed through jaxopt's core
`custom_root` mechanism in order to set it up with an
implicit-diff-based custom VJP. Currently, that `custom_root`
mechanism assumes that the solver function it receives has, in its
closure, no arrays that are involved in any of jax's differentiation
or staging. Re-stated using jax-internal jargon: `custom_root` assumes
that the solver function it receives does not have tracers in its
closure. But: a bound Python method (e.g. `o.run`) carries its bound
instance (e.g. `o`) in its closure.

The code in `bisection_test.py` did not conform to this requirement
that all class attributes are static (in the jax transformation
sense). Specifically, it constructed a `Bisection` instance, within a
jitted function, given parameters (`lower` and `upper`) that depend on
inputs to the jitted function. This change fixes that by hoisting the
construction of this `Bisection` out from the jitted function (and
marking it a static argument).

Doing this fixes a jax "tracer leak" error raised in the jaxopt CI
recently. This was not an issue until jax released version 0.4.4, for
the rather technical reason that jax changed its `jit` implementation
such that it eagerly stages out its function argument. This in turn
led jax to encounter "jit tracers" (corresponding to
`Bisection.{lower,upper}`) within the closure of a solver function
(`Bisection.run`) in the course of custom-differentiating the solver
function.
  • Loading branch information
froystig committed Mar 15, 2023
1 parent cb6ed9a commit e196ece
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions tests/bisection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@ def _optimality_fun_proj_simplex(tau, x, s):
return jnp.sum(jnp.maximum(x - tau, 0)) - s


def _threshold_proj_simplex(x, s):
def _threshold_proj_simplex(bisect, x, s=1.0):
return bisect.run(None, x, s).params


def _projection_simplex_bisect(bisect, x, s=1.0):
return jnp.maximum(x - _threshold_proj_simplex(bisect, x, s), 0)


def _projection_simplex_bisect_setup(x, s=1.0):
# tau = max(x) => tau >= x_i for all i
# => x_i - tau <= 0 for all i
# => maximum(x_i - tau, 0) = 0 for all i
Expand All @@ -45,33 +53,30 @@ def _threshold_proj_simplex(x, s):
# where tau = tau' - s / len(x)
lower = jax.lax.stop_gradient(jnp.min(x)) - s / len(x)

bisect = Bisection(optimality_fun=_optimality_fun_proj_simplex,
lower=lower, upper=upper, check_bracket=False)
return bisect.run(None, x, s).params


def _projection_simplex_bisect(x, s=1.0):
return jnp.maximum(x - _threshold_proj_simplex(x, s), 0)
return Bisection(optimality_fun=_optimality_fun_proj_simplex,
lower=lower, upper=upper, check_bracket=False)


class BisectionTest(test_util.JaxoptTestCase):

def test_bisect(self):
rng = onp.random.RandomState(0)

jitted_fun = jax.jit(_projection_simplex_bisect)
_projection_simplex_bisect_jitted = jax.jit(
_projection_simplex_bisect, static_argnums=0)

for _ in range(10):
x = jnp.array(rng.randn(50).astype(onp.float32))
bisect = _projection_simplex_bisect_setup(x)
p = projection.projection_simplex(x)
p2 = _projection_simplex_bisect(x)
p3 = jitted_fun(x)
p2 = _projection_simplex_bisect(bisect, x)
p3 = _projection_simplex_bisect_jitted(bisect, x)
self.assertArraysAllClose(p, p2, atol=1e-4)
self.assertArraysAllClose(p, p3, atol=1e-4)

J = jax.jacrev(projection.projection_simplex)(x)
J2 = jax.jacrev(_projection_simplex_bisect)(x)
J3 = jax.jacrev(jitted_fun)(x)
J2 = jax.jacrev(_projection_simplex_bisect, argnums=1)(bisect, x)
J3 = jax.jacrev(_projection_simplex_bisect_jitted, argnums=1)(bisect, x)
self.assertArraysAllClose(J, J2, atol=1e-5)
self.assertArraysAllClose(J, J3, atol=1e-5)

Expand All @@ -96,7 +101,9 @@ def test_bisect_wrong_upper_bracket(self):
def test_grad_of_value_and_grad(self):
# See https://github.com/google/jaxopt/issues/141

bisect = lambda x: _projection_simplex_bisect(x)[0]
def bisect(x):
b = _projection_simplex_bisect_setup(x)
return _projection_simplex_bisect(b, x)[0]

def bisect_val(x):
val, _ = jax.value_and_grad(bisect)(x)
Expand Down

0 comments on commit e196ece

Please sign in to comment.