_TrivialClosureConvert mis-specifies dynamic attributes #768
This is intentional! Given some pytree-of-arrays, we then call the code at lines 543 to 549 in 5a5bf28. In this case we actually store the pytree-of-`ShapeDtypeStruct`s as a static […].

I appreciate the heads-up though -- I can see how the 'dynamic' in the name suggests otherwise!
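Here is a rough illustration of why it matters what ends up as static data. JAX-style transformations key their caches on static metadata, so anything stored as static must be hashable. Shape/dtype information is hashable, but a tracer is not. The sketch below is a pure-Python analogue, not Equinox or JAX code. `FakeTracer`, `trace_cache` and `cached_trace` are hypothetical names chosen for illustration.

```python
from typing import Any, Callable, Hashable


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer: it cannot be hashed,
    mirroring what the 'tracer-hash' deprecation enforces."""

    __hash__ = None  # setting __hash__ to None makes instances unhashable


# Hypothetical jit-style cache: one entry per distinct static key.
trace_cache: dict = {}


def cached_trace(static_key: Hashable, build: Callable[[], Any]) -> Any:
    """Return the cached result for static_key, building it on a miss.

    The dict lookup hashes static_key, so every component must be hashable.
    """
    if static_key not in trace_cache:
        trace_cache[static_key] = build()
    return trace_cache[static_key]


# Shape/dtype metadata (analogous to a ShapeDtypeStruct) is hashable: fine as static data.
ok = cached_trace(((3,), "float32"), lambda: "compiled")

# A tracer-like value smuggled into the static key fails as soon as it is hashed.
try:
    cached_trace(((3,), FakeTracer()), lambda: "compiled")
    failed = False
except TypeError:
    failed = True
```

The failing lookup never inserts anything, so the cache still holds only the single hashable entry.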
Ah, my mistake, I see what you mean. I think I misdiagnosed this, but there's a real issue related to tracer hashing (deprecated in jax-ml/jax#21863) and related to […]:

```python
import jax
import equinox
import diffrax
print(f"{jax.__version__=}\n{equinox.__version__=}\n{diffrax.__version__=}")
# jax.__version__='0.4.30'
# equinox.__version__='0.11.4'
# diffrax.__version__='0.5.1'
```

```python
# Accelerate the deprecation of tracer hashability, so it raises a TypeError rather than a warning.
# This will be the behavior in a future JAX release.
from jax._src import deprecations
deprecations.accelerate('tracer-hash')
jax.jit(hash)(1)
# TypeError: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```

```python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController
vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)
# TypeError: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```

Something in this call stack ends up computing the hash of a tracer... I thought it may have been in […]
Ah, I see! Haha, looks like we were hashing tracers all along and never even noticed. This should be fixed in #769.

For the record -- i.e. future me, if I ever need this -- the problem was that we generate a […], which is a callable that includes some tracers in its pytree structure. This ultimately hits a […]. The fix is just to add a dummy […].
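The sketch below is a pure-Python analogue of the failure mode described above: a callable whose hash is derived from its contents, including a tracer-like value. It is not the Equinox implementation and not necessarily what #769 does. The callable is modelled as a hypothetical frozen dataclass, whose auto-generated `__hash__` hashes every compared field. `FakeTracer`, `ClosureConverted` and `ClosureConvertedNoHash` are illustrative names. Excluding a field with `dataclasses.field(compare=False)` is shown only as one generic way to keep a value out of the hash.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


class FakeTracer:
    """Hypothetical unhashable stand-in for a JAX tracer."""

    __hash__ = None


@dataclass(frozen=True)
class ClosureConverted:
    """Hypothetical closure-converted callable: a function plus the value it
    closed over. frozen=True with the default eq=True makes dataclasses
    generate __hash__ from every compared field, including closed_over."""

    fn: Callable[..., Any]
    closed_over: Any

    def __call__(self, *args):
        return self.fn(self.closed_over, *args)


@dataclass(frozen=True)
class ClosureConvertedNoHash:
    """Same callable, but closed_over is excluded from __eq__ and __hash__."""

    fn: Callable[..., Any]
    closed_over: Any = field(compare=False)

    def __call__(self, *args):
        return self.fn(self.closed_over, *args)


def add(x, y):
    return x + y


f = ClosureConverted(add, FakeTracer())
try:
    hash(f)  # hashes (fn, closed_over); the tracer-like field raises TypeError
    hash_failed = False
except TypeError:
    hash_failed = True

g = ClosureConvertedNoHash(add, FakeTracer())
g_hash = hash(g)  # only fn is hashed, so this succeeds

result = ClosureConverted(add, 1)(2)  # calling is unaffected: add(1, 2)
```

Calling the object never touches its hash. The error only appears when something, such as a cache or a static pytree node comparison, asks for one.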
I believe this should say `static=False`:

`equinox/equinox/_ad.py`, line 560 in 5a5bf28