-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jaxifying forward transform refactored #128
Conversation
…da fn wrapper around phase_shift function.
Suggested next steps for me (open to feedback):
|
This looks great @sfmig, although I'm still trying to fully understand it. Are all the vmapped |
Actually, ignore my previous comment. @CosmoMatt highlighted to me that vmap returns a function and your mapping over the computations in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really good @sfmig. I've added some comments and suggested changes - a lot of these are with regards to the changes to s2fft.healpix_ffts
which appreciate you said you had only done an initial quick pass at JAXifying, so some of my comments might be things you were already aware of. I didn't test all of my suggested changes as well so it might be that in some cases there are good reasons for sticking with the current implementation if it works and what I've suggested doesn't 😅!.
Quick fixes from code review Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
quick fixes to healpix_ffts jax functions Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
…matics/s2fft into jax-fwd-transform-refactored
…ion of p2phi_ring
@sfmig I pulled and ran the forward JAX transform on a GPU machine (with Nvidia a100's), and I think your implementation seems to be doing fine. The real downside at the moment is that, although we split the recursions for each theta and el with a tl;dr: I think this implementation is working as expected (and should be very fast), but we need to focus on accelerating the recursion. |
Hi all,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @sfmig for your work on this and the really helpful summary in your comment - this is looking really good. Having of all of the different JAX implementation options available should be very helpful in trying to do some more extensive benchmarking to figure out the tradeoffs between the differing approaches when running on CPU and GPU. I've added a few suggested changes to remove commented out old code and responding to some points you raised in comments, but this otherwise looks good to merge from my perspective once @CosmoMatt and/or @jasonmcewen have also had a chance to review.
removing some commented out blocks Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
… jit-compiled version
…d behaviour due to out of bounds indexing
This was a solid effort, just seems that JAX really wasn't playing ball with our original approach. Now we have a working alternative I'm going to close this PR for alpha release. |
A draft PR to update on the JAXified forward transform
Main additions:
Added a JAX version of the forward transform,
_compute_forward_sov_fft_vectorized_jax
, in the refactored transforms module.Added masking of output flm. To match the output
flm
array to the value obtained froms2f.utils.generate_flm
(for the same parameters), I mask with zeros the spurious values inflm
that result from thewigner.turok_jax.compute_slice
operation.Added the 'healpix' sampling option to the JAX forward transform. This involved vmapping the computation of the phase_shift, which in turn required some modifications in the underlying functions
ring_phase_shift_hp
andp2phi_ring
. For now I made_vmappable
versions of these two functions, but I assume we probably want to have just one function in the long term.JIT transformed the JAX forward transform. This required some additional changes in the
s2fft.healpix_ffts
module to make JIT work if the selected sampling scheme is 'healpix'. Specifically, I added jax-ammenable versions of two functions,healpix_fft_jax
andspectral_periodic_extension_jax
. For now I added them to the module as separate functions, and I only did a naive/quick JAXifying. I think they could be JAXified further, will continue to have a look.Added some notebooks to the notebook directory. I don't expect these to be merged with main but I include them now for reference. Mostly exploratory, to check against groundtruth and to time different implementations.
Re performance, as we discussed before the simply vectorised JAX code was not much better (or worse) than the Numpy implementation (
sov_fft_vectorized
), but adding JIT really made a difference. For L=5, spin=2 the Numpy implementation takes around 2-3 ms, while the JAX-JIT one takes around 100-300 µs (just using timeit). I was thinking of having a go in the cluster with a GPU next.I tagged Matt G as reviewer for more detailed comments on the PR, but any feedback or comments are welcome.