Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jaxifying forward transform refactored #128

Closed
wants to merge 65 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
5786bc6
added forward transform vectorised with jax
sfmig Nov 28, 2022
71d55b7
small changes to notebook
sfmig Nov 28, 2022
418056c
added jax forward transform, pending healpix.
sfmig Nov 29, 2022
ebecb0b
exploring jit
sfmig Nov 29, 2022
f0e38e5
checking phase shift for healpix
sfmig Nov 29, 2022
f939e82
checking phase shift for healpix in fwd transform
sfmig Nov 29, 2022
3ac06ac
checking error when vmapping phase shift
sfmig Nov 29, 2022
f473a52
added healpix sampling to JAXed fwd transform
sfmig Nov 30, 2022
4cbdf8a
moved notebooks to separate dir
sfmig Nov 30, 2022
7b161a6
added jit to jax forward transform
sfmig Nov 30, 2022
727255b
changes to make forward transform with healpix jit-able. removed lamb…
sfmig Dec 1, 2022
aa724d0
Apply suggestions from code review
sfmig Dec 2, 2022
eb2a44d
Apply suggestions from code review
sfmig Dec 2, 2022
093dbe0
minor changes to healpix_fft_jax
sfmig Dec 2, 2022
e6edf60
Merge branch 'jax-fwd-transform-refactored' of github.com:astro-infor…
sfmig Dec 2, 2022
85ff634
replaced lax slicing with regular slicing in healpix_fft_jax
sfmig Dec 7, 2022
9b36268
removed explicit conversion to DeviceArrays
sfmig Dec 7, 2022
c5873cc
replaced jnp.where() approach by nested lax.conds() in vmappable vers…
sfmig Dec 7, 2022
68ebf42
further jaxifying healpix_fft_jax and spectral_periodic_extension_jax…
sfmig Dec 13, 2022
f694ab6
in forward transform, replaced vmap approach with lax.map as suggeste…
sfmig Dec 13, 2022
824ec8e
refactored spectral_periodic_extension_jax following Matt G's suggestion
sfmig Dec 14, 2022
637ec2d
removed commented block
sfmig Dec 14, 2022
007c4dd
separated map and vmap implementations
sfmig Dec 15, 2022
cee9de6
three approaches to JAXifying healpix_fft. The lax.scan one (2) is st…
sfmig Dec 16, 2022
804bdfb
a notebook to compare healpix_fft JAX approaches
sfmig Dec 16, 2022
0908161
a vmappable way to compute the number of phi samples for healpix (unu…
sfmig Dec 16, 2022
d6ea835
removed some comments
sfmig Dec 16, 2022
ff0e845
keeping Healpix FFT JAX implementation using jax.numpy/numpy stack only
sfmig Dec 19, 2022
1dee88b
some notebooks to check against groundtruth
sfmig Dec 19, 2022
7f04e06
Merge branch 'main' of github.com:astro-informatics/s2fft into jax-fw…
sfmig Dec 19, 2022
e4f09a3
added numpy module to doc string in healpix_fft_jax
sfmig Dec 19, 2022
f76c81b
first attempt at adding reality to _compute_forward_sov_fft_vectorize…
sfmig Dec 19, 2022
59527f3
some work trying to add reality bits to turok_jax (in progress)
sfmig Dec 20, 2022
637e462
added reality option to turok_jax and forward transform
sfmig Jan 11, 2023
80b9204
changes dl computation to vmap along el dimension only to prevent oom…
sfmig Jan 11, 2023
833177c
replaced manual loop across theta in dl computation with a lax.scan, …
sfmig Jan 11, 2023
97fafc1
added implementation with vmap and manual loop over theta and refacto…
sfmig Jan 11, 2023
5480154
black formatting
sfmig Jan 11, 2023
49c3a0d
fixed phase shift in jax_vmap_loop, and added versions without .at in…
sfmig Jan 12, 2023
3d4811a
added double map implementation for dl with reality
sfmig Jan 12, 2023
4a675ac
added reality to map+scan approach
sfmig Jan 13, 2023
acc19cc
refactoring bits
sfmig Jan 13, 2023
e3afe81
fixed phase shift for healpix case in jax_vmap_loop and jax_vmap_loop_0
sfmig Jan 13, 2023
b0fdfa4
notebook to check jax implementations of fwd transform vs groundtruth
sfmig Jan 13, 2023
82b4dc4
factored out common bits of jax implementations under _compute_forwar…
sfmig Jan 13, 2023
755a575
refactored jax implementations further (same variable names and axes …
sfmig Jan 16, 2023
745ed4f
refactored notebook to check against ground truth
sfmig Jan 16, 2023
bcaa4e2
refactored notebook to check against ground truth
sfmig Jan 16, 2023
6455b6e
removed previous implementations from jax list in _forward
sfmig Jan 16, 2023
3627513
added tests for jax implementations
sfmig Jan 16, 2023
8385bd6
cosmetic changes
sfmig Jan 16, 2023
4bb85bb
added a test for JAX implementations with healpix sampling
sfmig Jan 16, 2023
57a4855
refactored supporting functions in samples and healpix_ffts
sfmig Jan 16, 2023
ee438fe
removed old notebooks. kept latest one comparing to groundtruth.
sfmig Jan 16, 2023
81c7a60
changed np.pi to the selected module (jnp or np)
sfmig Jan 16, 2023
84907ae
removed healpix_jax test using vectorized approach
sfmig Jan 16, 2023
d572797
removed jax methods from healpix tests
sfmig Jan 16, 2023
a3c4876
Apply suggestions from code review
sfmig Jan 23, 2023
8f2551e
added numpy-only original implementation of p2phi_ring function
sfmig Jan 23, 2023
77d1cea
small refactoring in the symmetry loop for readability
sfmig Jan 23, 2023
2c25b54
commented warning for invalid spin value since it will be lost in the…
sfmig Jan 23, 2023
d882d41
add list of supported JAX methods in the docstring for the method arg…
sfmig Jan 30, 2023
790f9f0
remove comment on alternative padding approach
sfmig Jan 30, 2023
b4cfde0
removed alternative phase_shift computation inside accumulate function
sfmig Jan 30, 2023
d72ab48
changed phase shift to be computed inside the loop to avoid unexpecte…
sfmig Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading