_TrivialClosureConvert mis-specifies dynamic attributes #768

Closed
jakevdp opened this issue Jun 21, 2024 · 3 comments · Fixed by #769

Comments

@jakevdp
Contributor

jakevdp commented Jun 21, 2024

I believe this should say static=False:

in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)

@patrick-kidger
Owner

patrick-kidger commented Jun 21, 2024

This is intentional! Given some pytree-of-arrays, we call eval_shape to transform it into a pytree-of-shapedtypestructs. We compare against this later, to be sure that we are subsequently called with compatible arguments:

equinox/equinox/_ad.py, lines 543 to 549 in 5a5bf28:

in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
# `is` because `tree_equal` may return a tracer
if tree_equal(in_dynamic_struct, self_in_dynamic_struct) is not True:
    raise ValueError(
        "Closure-converted function called with different dynamic arguments to "
        "the example arguments provided."
    )

In this case we actually store the pytree-of-shapedtypestructs as a static (tuple of shapedtypestructs, treedef). This allows the user to pass this object around safely -- across JIT boundaries, eval_shape boundaries etc. -- without those shapedtypestructs erroneously getting turned back into dynamic objects.
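To make the static-versus-dynamic point concrete, here is a minimal sketch in plain JAX (no Equinox). It assumes only jax.eval_shape and jax.tree_util.register_pytree_node_class; ExampleArgSpec, make_spec and matches are hypothetical names for illustration, not Equinox's actual _TrivialClosureConvert. The flattened ShapeDtypeStructs and their treedef are kept in the pytree's auxiliary (static) data, so they cross a jax.jit boundary unchanged while the shape/dtype check still works.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ExampleArgSpec:
    """Hypothetical container: the flattened ShapeDtypeStructs and their
    treedef are stored as static (aux) data rather than as pytree leaves."""

    def __init__(self, structs, treedef):
        self.structs = structs  # tuple of jax.ShapeDtypeStruct
        self.treedef = treedef  # PyTreeDef of the example arguments

    def tree_flatten(self):
        # No dynamic children: everything lives in the aux data, so JAX
        # transformations never replace the structs with tracers.
        return (), (self.structs, self.treedef)

    @classmethod
    def tree_unflatten(cls, aux, children):
        del children  # always empty
        structs, treedef = aux
        return cls(structs, treedef)

    def matches(self, args):
        """True if `args` has the same tree structure, shapes and dtypes."""
        leaves, treedef = jax.tree_util.tree_flatten(args)
        if treedef != self.treedef:
            return False
        return all(
            jnp.shape(x) == s.shape and jnp.result_type(x) == s.dtype
            for x, s in zip(leaves, self.structs)
        )


def make_spec(example_args):
    # Same trick as the excerpt above: eval_shape turns a pytree of arrays
    # into a pytree of jax.ShapeDtypeStruct without doing any real computation.
    structs = jax.eval_shape(lambda: example_args)
    leaves, treedef = jax.tree_util.tree_flatten(structs)
    return ExampleArgSpec(tuple(leaves), treedef)


spec = make_spec((jnp.zeros(3), jnp.float32(1.0)))


@jax.jit
def double(spec, x):
    # spec crossed the JIT boundary, yet its entries are still ShapeDtypeStructs.
    assert all(isinstance(s, jax.ShapeDtypeStruct) for s in spec.structs)
    return 2 * x


print(double(spec, jnp.ones(3)))
```

Because spec has no leaves, jax.jit treats it entirely as part of the (hashable) tree structure; had the structs been leaves, JAX would have tried to treat them as array arguments.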

I appreciate the heads-up though -- I can see how the 'dynamic' in the name suggests otherwise!

@jakevdp
Contributor Author

jakevdp commented Jun 22, 2024

Ah, my mistake, I see what you mean. I think I misdiagnosed this, but there is a real issue related to tracer hashing (deprecated in jax-ml/jax#21863) that involves ClosureConvert. Here's a minimal repro via diffrax:

import jax
import equinox
import diffrax

print(f"{jax.__version__=}\n{equinox.__version__=}\n{diffrax.__version__=}")
# jax.__version__='0.4.30'
# equinox.__version__='0.11.4'
# diffrax.__version__='0.5.1'
# Accelerate the deprecation of tracer hashability, so it raises a TypeError rather than a warning.
# This will be the behavior in a future JAX release.
from jax._src import deprecations
deprecations.accelerate('tracer-hash')
# Sanity check (commented out, since it raises and would stop the script here):
# jax.jit(hash)(1)
# -> TypeError: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)
# TypeError: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>

Something in this call stack ends up computing the hash of a tracer. I thought it might have been in _TrivialClosureConvert, but it looks like it's something else (though closely related). It seems to be caused by passing a _ClosureConvert object to while_loop. Any thoughts on this?

@patrick-kidger
Owner

patrick-kidger commented Jun 22, 2024

Ah, I see! Haha, looks like we were hashing tracers all along and never even noticed. This should be fixed in #769.


For the record -- i.e. future me, if I ever need this -- the problem was that we generate a _ClosureConvert on this line:

cond_fun_ = filter_closure_convert(cond_fun_, init_val_)

which is a callable that includes some tracers in its pytree structure. This ultimately hits a lax.while_loop(cond_fun, ...). Now, lax.while_loop hashes its function arguments in order to cache their jaxprs, as in jax-ml/jax#13554. This previously worked because tracers were hashable, but once tracers become unhashable, it explodes.

The fix is just to wrap it in a dummy lambda val: cond_fun_(val). A plain lambda hashes by identity, so the tracers inside cond_fun_ are never hashed.
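The failure mode and the workaround can be reproduced in plain Python, without JAX. This is a minimal sketch assuming, as described above, that the loop primitive caches per-function work by hashing the function, and that the closure-converted callable's hash is computed from its contents. FakeTracer, ClosureConvertLike and trace_once are hypothetical stand-ins, not JAX or Equinox APIs.

```python
import dataclasses
import functools


class FakeTracer:
    """Hypothetical stand-in for a tracer once tracer hashing is removed."""

    def __hash__(self):
        raise TypeError(f"unhashable type: {type(self)!r}")


@dataclasses.dataclass(frozen=True)
class ClosureConvertLike:
    """Hypothetical analogue of a closure-converted callable. As a frozen
    dataclass, its __hash__ is computed from its fields, tracer included."""

    fn: object
    consts: tuple

    def __call__(self, val):
        return self.fn(val, *self.consts)


calls = {"n": 0}  # counts cache misses in trace_once


@functools.lru_cache(maxsize=None)
def trace_once(fun):
    """Hypothetical stand-in for a primitive that caches per-function work;
    lru_cache must hash `fun` to build its cache key."""
    calls["n"] += 1
    return "jaxpr"


cond_fun_ = ClosureConvertLike(lambda val, tracer: val < 3, (FakeTracer(),))

try:
    trace_once(cond_fun_)  # hashing the dataclass hashes the FakeTracer
except TypeError as err:
    print("cache lookup failed:", err)

# The workaround: a plain lambda hashes by identity, so the tracer held
# inside cond_fun_ is never hashed.
wrapped = lambda val: cond_fun_(val)
print(trace_once(wrapped), wrapped(1))
```

In this sketch, identity hashing means every newly created wrapper is a cache miss; the trade is giving up cache hits across fresh wrappers in exchange for never hashing the tracer.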
