Skip to content

Conversation

@ealt
Copy link
Collaborator

@ealt ealt commented Aug 5, 2025

Refactor full_equation to be compatible with JAX jit and vmap for improved performance and parallelization.

The original full_equation contained JAX vectorization blockers such as data-dependent while loops, dynamic array indexing, and mutable state. This refactoring replaces these with JAX primitives like jax.lax.scan, jax.lax.fori_loop, and jax.lax.select to enable efficient compilation and vectorization.


Open in Cursor Open in Web

Co-authored-by: ericallenalt <ericallenalt@gmail.com>
@cursor
Copy link

cursor bot commented Aug 5, 2025

Cursor Agent can help with this pull request. Just @cursor in comments and I'll start working on changes in this branch.
Learn more about Cursor Agents

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants