
Jaxifying forward transform refactored #128

Closed
wants to merge 65 commits into from

Conversation

sfmig
Collaborator

@sfmig sfmig commented Dec 1, 2022

A draft PR to update on the JAXified forward transform

Main additions:

  • Added a JAX version of the forward transform, _compute_forward_sov_fft_vectorized_jax, in the refactored transforms module.

  • Added masking of the output flm. To match the output flm array to the value obtained from s2fft.utils.generate_flm (for the same parameters), I zero out the spurious values in flm that result from the wigner.turok_jax.compute_slice operation.

  • Added the 'healpix' sampling option to the JAX forward transform. This involved vmapping the computation of the phase_shift, which in turn required some modifications in the underlying functions ring_phase_shift_hp and p2phi_ring. For now I made _vmappable versions of these two functions, but I assume we probably want to have just one function in the long term.

  • JIT-transformed the JAX forward transform. This required some additional changes in the s2fft.healpix_ffts module to make JIT work when the selected sampling scheme is 'healpix'. Specifically, I added JAX-amenable versions of two functions, healpix_fft_jax and spectral_periodic_extension_jax. For now I added them to the module as separate functions, and I only did a naive/quick JAXifying pass. I think they could be JAXified further; I will continue to have a look.

  • Added some notebooks to the notebook directory. I don't expect these to be merged with main but I include them now for reference. Mostly exploratory, to check against groundtruth and to time different implementations.

  • Re performance: as we discussed before, the simply vectorised JAX code was not much better (or worse) than the NumPy implementation (sov_fft_vectorized), but adding JIT really made a difference. For L=5, spin=2, the NumPy implementation takes around 2-3 ms, while the JAX+JIT one takes around 100-300 µs (measured with timeit). I was thinking of having a go on the cluster with a GPU next.

    | Sampling scheme | NumPy (ms) | JAX+JIT (ms) |
    | --- | --- | --- |
    | mw | 3.27 | 0.32 |
    | mwss | 2.97 | 0.13 |
    | dh | 3.05 | 0.08 |
    | healpix | 2.25 | 0.08 |
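For reference, a minimal sketch of how timings like those above can be gathered for jitted JAX code. The `forward` function here is a hypothetical stand-in, not the actual s2fft transform; the key points are warming up the jitted function first and blocking on the result:

```python
import timeit

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the forward transform; the timings in the
# table were gathered with timeit on the real implementations.
def forward(f):
    return jnp.fft.fft(f, axis=-1)


forward_jit = jax.jit(forward)

f = jnp.ones((32, 63), dtype=jnp.complex64)
# Warm-up call so compilation time is excluded from the measurement.
forward_jit(f).block_until_ready()

# block_until_ready() forces JAX's asynchronous dispatch to finish before
# the timer stops, so the measurement reflects actual runtime.
t = timeit.timeit(lambda: forward_jit(f).block_until_ready(), number=100)
```

Without `block_until_ready()`, JAX's asynchronous dispatch can make jitted code appear misleadingly fast, since the call returns before the computation has finished.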

I tagged Matt G as reviewer for more detailed comments on the PR, but any feedback or comments are welcome.
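As an illustration of the flm masking mentioned above, here is a minimal sketch that zeros entries with |m| > el. It assumes an (L, 2L-1) flm layout with m running from -(L-1) to L-1; this is not the actual s2fft code, just the idea:

```python
import jax.numpy as jnp


def mask_flm(flm, L):
    # Zero entries with |m| > el, which can hold spurious values after
    # the Turok slice computation. Illustrative sketch only.
    el = jnp.arange(L)[:, None]            # shape (L, 1)
    m = jnp.arange(-(L - 1), L)[None, :]   # shape (1, 2L-1)
    return jnp.where(jnp.abs(m) <= el, flm, 0.0)


L = 4
masked = mask_flm(jnp.ones((L, 2 * L - 1)), L)
```

Each row el then keeps only its 2*el + 1 valid m entries, matching the triangular (el, m) structure of the spherical harmonic coefficients.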

@sfmig
Collaborator Author

sfmig commented Dec 1, 2022

Suggested next steps for me (open to feedback):

  • further JAXify the healpix_fft functions, healpix_fft_jax and spectral_periodic_extension_jax
  • refactor duplicate functions (i.e., combine JAX and non-JAX versions)
  • check running the JAX transform on a GPU
  • add tests if required
  • move to JAXifying inverse transform

@jasonmcewen jasonmcewen self-requested a review December 1, 2022 23:25
@jasonmcewen
Contributor

jasonmcewen commented Dec 2, 2022

This looks great @sfmig, although I'm still trying to fully understand it. Are all the vmapped dl computed and stored before being used to compute flm? If so, we want to avoid that, since it will require a lot of memory as we go to larger bandlimits L. We want to compute the dl for the given el and t, use it, and then discard it.

@jasonmcewen
Contributor

Actually, ignore my previous comment. @CosmoMatt highlighted to me that vmap returns a function and you're mapping over the computations in the flm computation.

Collaborator

@matt-graham matt-graham left a comment


This looks really good @sfmig. I've added some comments and suggested changes - a lot of these relate to the changes to s2fft.healpix_ffts, which I appreciate you said you had only done an initial quick pass at JAXifying, so some of my comments might be things you were already aware of. I also didn't test all of my suggested changes, so it might be that in some cases there are good reasons for sticking with the current implementation if it works and what I've suggested doesn't 😅!

Review threads (outdated, resolved) on: s2fft/samples.py, s2fft/healpix_ffts.py, s2fft/transform.py
sfmig and others added 7 commits December 2, 2022 16:46
Quick fixes from code review

Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
quick fixes to healpix_ffts jax functions

Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
…matics/s2fft into jax-fwd-transform-refactored
@CosmoMatt
Collaborator

@sfmig I pulled and ran the forward JAX transform on a GPU machine (with NVIDIA A100s), and I think your implementation seems to be doing fine. The real downside at the moment is that, although we split the recursions for each theta and el with a vmap, the computation of only a single one of these parallel threads (i.e. a single Turok recursion) on the GPU is 10 times slower than the entire pyssht transform.

tl;dr: I think this implementation is working as expected (and should be very fast), but we need to focus on accelerating the recursion.

@sfmig
Collaborator Author

sfmig commented Jan 17, 2023

Hi all,
Below is a summary of the latest additions to this PR, which I think is now ready for review.

  • Different implementations and tradeoffs: the JAX implementations of the forward transform are named based on the approach used to scan over el and theta:

    | method | scan over el | scan over theta | flm computation |
    | --- | --- | --- | --- |
    | jax_vmap_double | vmap | vmap | summing along the last axis of a 3D array |
    | jax_vmap_scan | vmap | lax.scan | cumulative sum on a 2D array |
    | jax_vmap_loop | vmap | Python loop | cumulative sum on a 2D array |
    | jax_map_double | map | map | summing along the last axis of a 3D array |
    | jax_map_scan | map | lax.scan | cumulative sum on a 2D array |

    Some comments:

    • jax_vmap_double was the first approach we explored. vmap is in many respects optimal for GPU performance, but here we run into OOM issues at moderate L values (~1000), as noted by @CosmoMatt.
    • jax_vmap_scan and jax_vmap_loop should avoid these OOM issues because they compute flm as a cumulative sum (rather than as a 3D array, which we collapse into 2D by summing along its last axis).
    • Running the jax_vmap_ approaches on my laptop's CPU at low L values, I didn't find significant differences in execution time, apart from jax_vmap_loop being slightly slower. However, this may be different on a GPU.
    • The jax_map_ implementations were introduced to address the performance hit observed on the CPU when vmapping functions with nested conditionals (related to this issue). We found an order-of-magnitude improvement of map over vmap on the CPU, but the vmap implementation later performed better on the GPU.
    • All methods include support for real signals, for all sampling schemes except healpix (as is the case in the main branch right now).
    • I added a tag v0_pad_concat for a commit that includes two additional implementations of jax_vmap_scan and jax_vmap_loop, in which I use padding and array concatenation rather than in-place .at[] operators to build the full flm array. I didn't find significant differences in performance (timewise) on my laptop, so I decided to keep only the versions using .at[] operators, for simplicity.
  • Refactoring approach:

    • I followed the forward function definition to factor out common bits of the JAX implementations.
    • For the supporting functions in samples.py and healpix_ffts.py, I was able to define common functions for JAX and non-JAX implementations by adding the numpy_module as input, as suggested by @matt-graham.
    • However, for the p2phi_ring function I included a JAX-exclusive implementation based on lax.cond that might be interesting going forward. I didn't find significant differences from the JAX-agnostic implementation on my CPU, but it may be different on a GPU. Sadly, both approaches seem to compute all possible branches (see here).
  • Testing:

    • I added the JAX implementations to the testing of the forward transform. I was able to test all sampling schemes locally, but on GitHub Actions I think I hit the runner's limitations with healpix sampling on the JAX implementations. I wasn't sure how to fix that, so for now I have commented those tests out. Any feedback on that is welcome!
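To make the memory argument above concrete, here is a minimal sketch of the cumulative-sum pattern with lax.scan. The flm_contribution function is a hypothetical stand-in for the real per-theta computation (which would involve the Wigner recursion), and the array shapes are arbitrary:

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical per-theta contribution; in the real transform this step
# would involve the Wigner d-function recursion for this theta.
def flm_contribution(theta):
    return jnp.cos(theta) * jnp.ones((4, 7))


def accumulate_flm(thetas):
    def step(flm_acc, theta):
        # Compute this theta's contribution, add it, and let it be
        # discarded; only the 2D accumulator persists across iterations.
        return flm_acc + flm_contribution(theta), None

    flm, _ = lax.scan(step, jnp.zeros((4, 7)), thetas)
    return flm


thetas = jnp.linspace(0.0, jnp.pi, 16)
flm_scan = jax.jit(accumulate_flm)(thetas)

# Equivalent (but memory-hungrier) 3D approach: build all contributions
# with vmap, then sum over the theta axis.
flm_sum = jnp.sum(jax.vmap(flm_contribution)(thetas), axis=0)
```

The scan version never materialises the full (n_theta, ...) stack of contributions, which is why it avoids the OOM issues seen with the fully vmapped approach at larger L.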

@sfmig sfmig marked this pull request as ready for review January 17, 2023 12:44
@sfmig sfmig requested a review from matt-graham January 17, 2023 12:45
Review threads (outdated, resolved) on: s2fft/healpix_ffts.py, s2fft/samples.py, s2fft/transform.py
Collaborator

@matt-graham matt-graham left a comment


Thanks @sfmig for your work on this and the really helpful summary in your comment - this is looking really good. Having all of the different JAX implementation options available should be very helpful for doing some more extensive benchmarking to figure out the tradeoffs between the differing approaches when running on CPU and GPU. I've added a few suggested changes to remove commented-out old code and to respond to some points you raised in comments, but otherwise this looks good to merge from my perspective, once @CosmoMatt and/or @jasonmcewen have also had a chance to review.

@sfmig sfmig requested a review from matt-graham January 30, 2023 12:07
@CosmoMatt
Collaborator

This was a solid effort; it just seems that JAX really wasn't playing ball with our original approach. Now that we have a working alternative, I'm going to close this PR for the alpha release.

@CosmoMatt CosmoMatt closed this Feb 14, 2023
@CosmoMatt CosmoMatt deleted the jax-fwd-transform-refactored branch February 14, 2023 15:46