Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for compatibility with JAX version 0.4.28. #719

Merged
merged 1 commit into from
May 12, 2024
Merged

Conversation

patrick-kidger
Copy link
Owner

  • callback inputs are now Arrays rather than arrays.
  • emit_python_callback now uses a keyword argument for what was once a positional argument.

Other nice-to-haves whilst I'm here:

  • Removed filter_jit passing function default arguments across the JIT boundary. In particular this is incompatible with filter_jit(filter_vmap(eqx.nn.MLP(...))), as then it tries to pass the MLP.__call__(..., *, key=...) default argument through, but filter_vmap does not allow keyword arguments.
  • Added some stacklevels to some warnings.

- callback inputs are now `Array`s rather than `array`s.
- `emit_python_callback` now uses a keyword argument for what was once a positional argument.

Other nice-to-haves whilst I'm here:

- Removed `filter_jit` passing function default arguments across the JIT boundary. In particular this is incompatible with `filter_jit(filter_vmap(eqx.nn.MLP(...)))`, as then it tries to pass the `MLP.__call__(..., *, key=...)` default argument through, but `filter_vmap` does not allow keyword arguments.
- Added some `stacklevel`s to some warnings.
@patrick-kidger patrick-kidger merged commit 7de34a2 into main May 12, 2024
2 checks passed
@patrick-kidger patrick-kidger deleted the 0.4.28-fixes branch May 12, 2024 12:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant