Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
avoid closing over dynamic jax tracers in the bisection solver
Internally in jaxopt, we (should) try to maintain that the parameters to a solver class are "static" from jax's point of view. One reason for this is that class attributes might be read by any of the class' methods, including `run`. Meanwhile a bound `run` method serves as the solver function, which is passed through jaxopt's core `custom_root` mechanism in order to set it up with an implicit-diff-based custom VJP. Currently, that `custom_root` mechanism assumes that the solver function it receives has, in its closure, no arrays that are involved in any of jax's differentiation or staging. Re-stated using jax-internal jargon: `custom_root` assumes that the solver function it receives does not have tracers in its closure. But: a bound Python method (e.g. `o.run`) carries its bound instance (e.g. `o`) in its closure. The code in `bisection_test.py` did not conform to this requirement that all class attributes are static (in the jax transformation sense). Specifically, it constructed a `Bisection` instance, within a jitted function, given parameters (`lower` and `upper`) that depend on inputs to the jitted function. This change fixes that by hoisting the construction of this `Bisection` out from the jitted function (and marking it a static argument). Doing this fixes a jax "tracer leak" error raised in the jaxopt CI recently. This was not an issue until jax released version 0.4.4, for the rather technical reason that jax changed its `jit` implementation such that it eagerly stages out its function argument. This in turn led jax to encounter "jit tracers" (corresponding to `Bisection.{lower,upper}`) within the closure of a solver function (`Bisection.run`) in the course of custom-differentiating the solver function.
- Loading branch information