flax and haiku contrib tests fail under jax 0.4.36. #1934

Closed
tillahoffmann opened this issue Dec 6, 2024 · 4 comments · Fixed by #1935
Labels
bug Something isn't working

Comments

@tillahoffmann
Contributor

tillahoffmann commented Dec 6, 2024

The flax and haiku contrib tests fail under jax 0.4.36. A failed test run from #1932 shows the failures; the pytest output is below.

FAILED test/contrib/test_module.py::test_haiku_state_dropout_smoke[True-True] - jax.errors.UnexpectedTracerError: An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_haiku_state_dropout_smoke[True-False] - jax.errors.UnexpectedTracerError: An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_flax_state_dropout_smoke[True-True] - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[10] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_flax_state_dropout_smoke[True-False] - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[10] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
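
For readers unfamiliar with this error class, here is a minimal sketch of how an `UnexpectedTracerError` can arise in plain JAX. It is not the code path of the failing tests: the names `outs`, `side_effecting`, and `pure` are hypothetical and chosen only for illustration. A jitted function stores an intermediate value in global state, and using that stored value later raises the error; returning the value explicitly avoids it.

```python
import jax
import jax.numpy as jnp

outs = []  # global state; a transformed function should not write to it


@jax.jit
def side_effecting(x):
    y = x + 1
    outs.append(y)  # side effect: the tracer `y` escapes the jit trace
    return y


side_effecting(jnp.ones(10))

try:
    outs[0] + 1  # using the escaped tracer outside the transformation
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True


@jax.jit
def pure(x):
    return x + 1  # the output is returned explicitly, nothing is stored


result = pure(jnp.ones(10)) + 1
```

As the error message above suggests, setting `JAX_CHECK_TRACER_LEAKS` or wrapping the call in `jax.checking_leaks` may help locate where a value escapes.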
fehiepsi added the bug label on Dec 6, 2024
@fehiepsi
Member

fehiepsi commented Dec 7, 2024

From what I know, this issue comes from chex (some functionality is not compatible with jax 0.4.36). A new release will be available soon.

fehiepsi added the jax label on Dec 7, 2024
@juanitorduz
Contributor

juanitorduz commented Dec 11, 2024

I just tried it with the new release

> pip freeze | grep chex
chex==0.1.88

and I keep seeing the errors locally :(
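
When reporting a result like this, it can help to list the versions of all the relevant packages at once rather than only chex. A minimal sketch using only the standard library follows; the helper name `installed_versions` and the package list are assumptions for illustration (in particular `numpyro`, since the thread does not name the repository, and `dm-haiku` as the distribution name for Haiku).

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed package list; adjust to the environment being debugged.
    names = ["jax", "jaxlib", "chex", "flax", "dm-haiku", "numpyro"]
    for name, version in installed_versions(names).items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```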

fehiepsi removed the jax label on Dec 12, 2024
@fehiepsi
Member

Turns out that this is a bug which was not detected earlier. Sorry for the confusion.

@juanitorduz
Contributor

Thanks @fehiepsi 🙌
