The flax and haiku contrib tests fail under JAX 0.4.36. See here for a failed test run from #1932. The pytest output is below.
FAILED test/contrib/test_module.py::test_haiku_state_dropout_smoke[True-True] - jax.errors.UnexpectedTracerError: An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_haiku_state_dropout_smoke[True-False] - jax.errors.UnexpectedTracerError: An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_flax_state_dropout_smoke[True-True] - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[10] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
FAILED test/contrib/test_module.py::test_flax_state_dropout_smoke[True-False] - jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[10] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
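To show what the flax error message means in general, here is a minimal sketch of a tracer escaping a JAX transformation through a side effect. It follows the generic pattern described in the JAX `UnexpectedTracerError` documentation and does not reproduce the contrib failures themselves. The names `leaked` and `side_effecting` are hypothetical and used only for illustration.

```python
import jax

# Hypothetical module-level list standing in for "global state".
leaked = []


@jax.jit
def side_effecting(x):
    y = x + 1
    # Side effect: an intermediate tracer is stored outside the traced
    # function instead of being returned as an output.
    leaked.append(y)
    return x * 2


side_effecting(1.0)

caught = None
try:
    # Using the escaped tracer after tracing has finished raises.
    leaked[0] + 1
except jax.errors.UnexpectedTracerError as err:
    caught = err

print(type(caught).__name__)
```

The fix in such cases is to return the intermediate value explicitly rather than saving it. As the error message suggests, setting `JAX_CHECK_TRACER_LEAKS` or using `jax.checking_leaks` should surface the leak earlier, at the point where the tracer escapes.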