JAX 0.4.36 Failing #532

Closed
lockwo opened this issue Dec 6, 2024 · 2 comments

Comments

@lockwo
Contributor

lockwo commented Dec 6, 2024

With the latest release of JAX (0.4.36, released 1 hour ago), the latest version of Diffrax fails. #527 looks like it fixes the problem (at least the one I encountered), but it would be great to get an official release (or to disallow this version of JAX).
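For anyone who needs a stopgap before a release lands, "disallowing" a version could look roughly like the sketch below. This is only an illustration, not what Diffrax does: the helper names (`parse_release`, `check_jax_version`, `installed_jax_version`) and the `KNOWN_BAD_JAX` set are made up for this example, and the only real API used is `importlib.metadata.version`.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: 0.4.36 is the only release treated as incompatible in this sketch.
KNOWN_BAD_JAX = {(0, 4, 36)}


def parse_release(version_string):
    """Return the leading numeric components of a version string as a tuple."""
    parts = []
    for piece in version_string.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            # Stop at the first component with no leading digits, e.g. "dev1".
            break
        parts.append(int(digits))
    return tuple(parts)


def check_jax_version(version_string):
    """Raise if the given version string names a release listed as bad."""
    if parse_release(version_string)[:3] in KNOWN_BAD_JAX:
        raise RuntimeError(
            f"jax {version_string} is known to be incompatible; "
            "pin a different jax release or upgrade this package."
        )


def installed_jax_version():
    """Return the installed jax version string, or None if jax is absent."""
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_jax_version()
    if found is not None:
        check_jax_version(found)
```

At the packaging level, the same effect would come from an exclusion specifier such as `jax!=0.4.36` in the dependency list, which stops installers from selecting that release in the first place.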

@lockwo
Contributor Author

lockwo commented Dec 6, 2024

I stand corrected: https://github.com/patrick-kidger/diffrax/blob/dev/test/test_adjoint.py#L231 also seems to be failing:

test_adjoint.py:231:44 - error: "CustomVJPException" is not a known member of module "jax.interpreters.ad" 
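One generic way to tolerate a name that moves between library versions is to look it up dynamically and fall back to a broader exception type. The sketch below is an assumption-heavy illustration: the thread does not say where `CustomVJPException` lives in JAX 0.4.36, so the "modules" here are hypothetical stand-ins built with `types.SimpleNamespace`, and `find_exception` is a made-up helper rather than anything from Diffrax or JAX.

```python
import types


def find_exception(candidates, name, fallback=Exception):
    """Return the first exception class called `name` on any candidate module."""
    for module in candidates:
        found = getattr(module, name, None)
        if isinstance(found, type) and issubclass(found, BaseException):
            return found
    # Nothing matched: fall back to a broader exception type.
    return fallback


# Hypothetical stand-ins for an old and a new module layout.
old_layout = types.SimpleNamespace(
    CustomVJPException=type("CustomVJPException", (Exception,), {})
)
new_layout = types.SimpleNamespace()  # the name is no longer exposed here

# Resolves to the class from whichever layout still provides it.
exc_type = find_exception([new_layout, old_layout], "CustomVJPException")
```

In a test, `exc_type` could then be passed to `pytest.raises`. Because the lookup goes through `getattr`, a static checker no longer sees a direct reference to the missing member.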

patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Dec 7, 2024
See:

- jax-ml/jax#25289
- patrick-kidger/diffrax#532

The problem was that batching has now become a dynamic trace, and our batching rules were not set up to handle the case that every batch axis is `not_mapped`.
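To make the "every batch axis is `not_mapped`" case concrete, here is a toy batching rule in plain Python. It is a sketch under stated assumptions, not the Equinox fix: `not_mapped` is modelled as `None` so the example runs without JAX installed, and `total_impl` and `total_batch_rule` are hypothetical names for a primitive that sums a flat list.

```python
# Stand-in for the sentinel marking an operand with no batch axis
# (assumption: modelled as None so this runs without JAX installed).
not_mapped = None


def total_impl(xs):
    """Hypothetical primitive: sum a flat list of numbers."""
    return sum(xs)


def total_batch_rule(batched_args, batch_dims):
    """Toy batching rule: returns (output, output_batch_dim)."""
    (xs,) = batched_args
    (dim,) = batch_dims
    if all(d is not_mapped for d in batch_dims):
        # No operand carries a batch axis, so apply the primitive directly
        # and report that the output is unmapped as well.
        return total_impl(xs), not_mapped
    if dim != 0:
        raise NotImplementedError("this sketch only handles batch axis 0")
    # Batched along axis 0: apply the primitive to each row.
    return [total_impl(row) for row in xs], 0
```

A rule that assumes at least one operand is mapped would skip the first branch and treat `[1, 2, 3]` as a list of rows, which is the kind of failure the commit message describes.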
patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Dec 8, 2024
See:

- jax-ml/jax#25289
- patrick-kidger/diffrax#532

The problem was that batching has now become a dynamic trace, and our batching rules were not set up to handle the case that every batch axis is `not_mapped`.
@patrick-kidger
Owner

Closing as resolved with the 0.6.1 release of Diffrax!
