JAX v0.4.37
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
- Bug fixes
- Fixed a bug where `jit` would error if an argument was named `f` (#25329).
- Fixed a bug that threw an `index out of range` error in `jax.lax.while_loop` if the user registered a pytree node class with different aux data for its `flatten` and `flatten_with_path` functions.
- Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
- Fixed a bug where