
sparse jacobian solve #545

Open
martinjrobins opened this issue Dec 17, 2024 · 4 comments
Labels
question User queries

Comments

@martinjrobins

Can the implicit solvers in diffrax use a sparse matrix solve for the Jacobian? I'm putting together a benchmark with a few different ODE solvers, including diffrax. The problem in question is stiff (it's based on the Robertson ODE problem from your examples) and has a small block-diagonal Jacobian structure (each block is 3x3 and the total matrix size goes up to 10,000, so it's very sparse). Diffrax slows down significantly at larger matrix sizes, and my assumption is that it is using a dense linear solver. Is there a way of swapping this out for a sparse solver? My understanding is that sparse matrix support is still rather experimental in JAX.

You can see the benchmark and results here: https://martinjrobins.github.io/diffsol/benchmarks/python.html. Please feel free to let me know if I'm not using diffrax correctly. I've not used it in anger before, so it's entirely possible I'm doing something stupid.

@patrick-kidger
Owner

Yup, both JAX and Diffrax are optimized around dense matrices. We do have some iterative linear solvers you can use (e.g. lineax.CG), although you'd need a sparse representation of the Jacobian to really take advantage of this.

Taking a quick look at the benchmark you have:

  • if you have multiple groups with the same problem structure, then this would be most efficiently handled using a jax.vmap over the groups. Basically, your sparsity is of a particular form that is amenable to SIMD parallelism!
  • I definitely wouldn't try to compare BDF and Kvaerno5; these are really two very different solvers.
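
To sketch the first point: assuming many independent 3x3 blocks as in your benchmark (the sizes below are illustrative), the block-diagonal structure becomes a batched dense solve:

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: many independent 3x3 blocks, mirroring the
# block-diagonal Jacobian described in the benchmark.
n_blocks, block_size = 1000, 3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
B = jax.random.normal(k1, (n_blocks, block_size, block_size))
# Make each block well-conditioned: A_i = B_i B_i^T + I.
A = B @ jnp.swapaxes(B, 1, 2) + jnp.eye(block_size)
b = jax.random.normal(k2, (n_blocks, block_size))

# vmap the per-block dense solve over the leading axis: the big
# (3000 x 3000) block-diagonal matrix is never formed, and the batched
# solve maps well onto SIMD/GPU hardware.
x = jax.vmap(jnp.linalg.solve)(A, b)
```
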

That aside, I'm curious to see a little more about the library you're writing here. We learnt a lot of lessons in writing Diffrax, many of them new with it, e.g.:

  • How to handle ODEs-vs-SDEs in a unified way;
  • How to cleanly implement events, including backprop through events;
  • The 'term' system that allows for expressing many kinds of ODE in the same language (whether they're general, symplectic, have IMEX-amenable structure, ...);
  • etc.!

So if you're tackling the problem of writing an equivalent library in a new ecosystem, then I'd be happy to offer some thoughts on how to ensure it builds on prior art both in the JAX ecosystem and in the Julia ecosystem (which I'm also reasonably familiar with).

I'm also curious whether your ambitions include GPU support and reverse-mode (or higher-order) autodiff? (Which, so far as I can see, aren't there atm?) I think these are becoming table stakes for new numerical software, and history has proven they are hard to retrofit onto older code that was not originally written with them in mind. (Supporting these use cases is a big part of why Diffrax exists!)

@patrick-kidger patrick-kidger added the question User queries label Dec 17, 2024
@martinjrobins
Author

Thanks @patrick-kidger. Good call on the jax.vmap, I'll give that a try. And a fair point on the differences between BDF and Kvaerno5; do you have any plans to implement BDF? I do have an ESDIRK solver in diffsol, however, so I'll use that for the comparison.

For the iterative solvers in lineax: presumably they just need the action of the Jacobian rather than a full sparse representation? And JAX should be able to get this for me via jvp? Once I've got iterative solvers into diffsol this would be a nice comparison to run.
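
For context, the kind of thing I mean, sketched with the Robertson right-hand side (the state and direction vectors are just illustrative):

```python
import jax
import jax.numpy as jnp

# Robertson right-hand side, as in the benchmark discussion.
def f(y):
    return jnp.array([
        -0.04 * y[0] + 1e4 * y[1] * y[2],
        0.04 * y[0] - 1e4 * y[1] * y[2] - 3e7 * y[1] ** 2,
        3e7 * y[1] ** 2,
    ])

y = jnp.array([1.0, 1e-4, 1e-2])
v = jnp.array([1.0, 0.0, 0.0])

# Jacobian-vector product J(y) @ v via jax.jvp, without materialising J.
# This is exactly the "action of the Jacobian" an iterative solver needs.
_, Jv = jax.jvp(f, (y,), (v,))

# Cross-check against the dense Jacobian (only feasible at small sizes).
J_dense = jax.jacfwd(f)(y)
```
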

Generally, it would be very useful to get your thoughts on the current state of the art for ODE solvers and where you see the interesting areas for growth. It would be great to organise a call sometime in the new year if you are amenable.

GPU support is definitely on my roadmap. diffsol is designed to be generic over vector, matrix, and linear solver types, so I can swap different linear algebra libraries in and out. At the moment I'm hampered by the fact that there are no decent GPU-based linear algebra libraries in Rust, but I'm also considering wrapping a C library as a temporary solution until the Rust ecosystem matures.

I've implemented the solution of the continuous adjoint equations, with checkpointing. It's not really mentioned in the docs yet because I'm still playing around with the API. I don't do any backprop through the solver, however. There is currently a PR into nightly Rust that will add autodiff via Enzyme (rust-lang/rust#124509), and once that lands I'm going to try to implement this as well.

I had a read through your Term system, which was really interesting. I've done something similar in purpose (though the structure is a bit different) using a system of operator traits (nonlinear, linear, etc.), and then defining the ODE equations as a set of operators. Different solvers can place bounds on these types that define which form of equations they are able to solve, so users get compile-time errors if their equations aren't suitable for the chosen solver. The way that you've incorporated SDEs into the Term system is great; I'd have to have a think about how this could fit into diffsol. Do you see much usage of SDEs with diffrax (compared with deterministic equations)?

@patrick-kidger
Owner

patrick-kidger commented Dec 18, 2024

I'd be very happy to take a PR on BDF! No plans to implement it myself right now.

On lineax and sparse solves -- haha, I think you're right. This should 'just work' at the moment, simply by passing the appropriate linear solver to the differential equation solver. (Since, under the hood, it will create a lineax.JacobianLinearOperator and will not materialise it unless it has to.)

As for the current SOTA -- very happy to have a call. I can see you're also in Oxford, and I'll be passing through London early in the new year if by any chance you're in town then too.

When it comes to backpropagation, I'd strongly discourage using the continuous adjoint. This is an almost universally bad idea with essentially no redeeming qualities. 😄 The best option is almost always discretise-then-optimise with recursive checkpointing. (See Equinox's checkpointed while loop.) In the future this may be supplanted by algebraically reversible differential equation solvers, but they're still an open research topic. I'll be really curious to see what autodiff looks like in Rust at the language level; I've mostly interacted with framework-level autodiff (PyTorch, JAX) and have had bad experiences with language-level autodiff before (Julia).

As for SDEs yup, I get quite a lot of users of these. Diffrax was actually originally a research project to see if we could write out the numerics for ODEs+SDEs in a single unified way (and we can), so they've been in since the start!

@martinjrobins
Author

While I won't be able to do a BDF PR to diffrax myself, I could get someone else started, as I've implemented BDF in JAX before here. It's not the greatest code for getting an understanding of the algorithm, due to the hoops that JAX makes you jump through; the SciPy implementation is easier to follow. @BradyPlanden has been doing some recent improvements to the JAX BDF code in PyBaMM, so I'll tag him on this in case he is interested in contributing to diffrax.
