Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jaxifying forward transform refactored #128

Closed
wants to merge 65 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
5786bc6
added forward transform vectorised with jax
sfmig Nov 28, 2022
71d55b7
small changes to notebook
sfmig Nov 28, 2022
418056c
added jax forward transform, pending healpix.
sfmig Nov 29, 2022
ebecb0b
exploring jit
sfmig Nov 29, 2022
f0e38e5
checking phase shift for healpix
sfmig Nov 29, 2022
f939e82
checking phase shift for healpix in fwd transform
sfmig Nov 29, 2022
3ac06ac
checking error when vmapping phase shift
sfmig Nov 29, 2022
f473a52
added healpix sampling to JAXed fwd transform
sfmig Nov 30, 2022
4cbdf8a
moved notebooks to separate dir
sfmig Nov 30, 2022
7b161a6
added jit to jax forward transform
sfmig Nov 30, 2022
727255b
changes to make forward transform with healpix jit-able. removed lamb…
sfmig Dec 1, 2022
aa724d0
Apply suggestions from code review
sfmig Dec 2, 2022
eb2a44d
Apply suggestions from code review
sfmig Dec 2, 2022
093dbe0
minor changes to healpix_fft_jax
sfmig Dec 2, 2022
e6edf60
Merge branch 'jax-fwd-transform-refactored' of github.com:astro-infor…
sfmig Dec 2, 2022
85ff634
replaced lax slicing with regular slicing in healpix_fft_jax
sfmig Dec 7, 2022
9b36268
removed explicit conversion to DeviceArrays
sfmig Dec 7, 2022
c5873cc
replaced jnp.where() approach by nested lax.conds() in vmappable vers…
sfmig Dec 7, 2022
68ebf42
further jaxifying healpix_fft_jax and spectral_periodic_extension_jax…
sfmig Dec 13, 2022
f694ab6
in forward transform, replaced vmap approach with lax.map as suggeste…
sfmig Dec 13, 2022
824ec8e
refactored spectral_periodic_extension_jax following Matt G's suggestion
sfmig Dec 14, 2022
637ec2d
removed commented block
sfmig Dec 14, 2022
007c4dd
separated map and vmap implementations
sfmig Dec 15, 2022
cee9de6
three approaches to JAXifying healpix_fft. The lax.scan one (2) is st…
sfmig Dec 16, 2022
804bdfb
a notebook to compare healpix_fft JAX approaches
sfmig Dec 16, 2022
0908161
a vmappable way to compute the number of phi samples for healpix (unu…
sfmig Dec 16, 2022
d6ea835
removed some comments
sfmig Dec 16, 2022
ff0e845
keeping Healpix FFT JAX implementation using jax.numpy/numpy stack only
sfmig Dec 19, 2022
1dee88b
some notebooks to check against groundtruth
sfmig Dec 19, 2022
7f04e06
Merge branch 'main' of github.com:astro-informatics/s2fft into jax-fw…
sfmig Dec 19, 2022
e4f09a3
added numpy module to doc string in healpix_fft_jax
sfmig Dec 19, 2022
f76c81b
first attempt at adding reality to _compute_forward_sov_fft_vectorize…
sfmig Dec 19, 2022
59527f3
some work trying to add reality bits to turok_jax (in progress)
sfmig Dec 20, 2022
637e462
added reality option to turok_jax and forward transform
sfmig Jan 11, 2023
80b9204
changes dl computation to vmap along el dimension only to prevent oom…
sfmig Jan 11, 2023
833177c
replaced manual loop across theta in dl computation with a lax.scan, …
sfmig Jan 11, 2023
97fafc1
added implementation with vmap and manual loop over theta and refacto…
sfmig Jan 11, 2023
5480154
black formatting
sfmig Jan 11, 2023
49c3a0d
fixed phase shift in jax_vmap_loop, and added versions without .at in…
sfmig Jan 12, 2023
3d4811a
added double map implementation for dl with reality
sfmig Jan 12, 2023
4a675ac
added reality to map+scan approach
sfmig Jan 13, 2023
acc19cc
refactoring bits
sfmig Jan 13, 2023
e3afe81
fixed phase shift for healpix case in jax_vmap_loop and jax_vmap_loop_0
sfmig Jan 13, 2023
b0fdfa4
notebook to check jax implementations of fwd transform vs groundtruth
sfmig Jan 13, 2023
82b4dc4
factored out common bits of jax implementations under _compute_forwar…
sfmig Jan 13, 2023
755a575
refactored jax implementations further (same variable names and axes …
sfmig Jan 16, 2023
745ed4f
refactored notebook to check against ground truth
sfmig Jan 16, 2023
bcaa4e2
refactored notebook to check against ground truth
sfmig Jan 16, 2023
6455b6e
removed previous implementations from jax list in _forward
sfmig Jan 16, 2023
3627513
added tests for jax implementations
sfmig Jan 16, 2023
8385bd6
cosmetic changes
sfmig Jan 16, 2023
4bb85bb
added a test for JAX implementations with healpix sampling
sfmig Jan 16, 2023
57a4855
refactored supporting functions in samples and healpix_ffts
sfmig Jan 16, 2023
ee438fe
removed old notebooks. kept latest one comparing to groundtruth.
sfmig Jan 16, 2023
81c7a60
changed np.pi to the selected module (jnp or np)
sfmig Jan 16, 2023
84907ae
removed healpix_jax test using vectorized approach
sfmig Jan 16, 2023
d572797
removed jax methods from healpix tests
sfmig Jan 16, 2023
a3c4876
Apply suggestions from code review
sfmig Jan 23, 2023
8f2551e
added numpy-only original implementation of p2phi_ring function
sfmig Jan 23, 2023
77d1cea
small refactoring in the symmetry loop for readability
sfmig Jan 23, 2023
2c25b54
commented warning for invalid spin value since it will be lost in the…
sfmig Jan 23, 2023
d882d41
add list of supported JAX methods in the docstring for the method arg…
sfmig Jan 30, 2023
790f9f0
remove comment on alternative padding approach
sfmig Jan 30, 2023
b4cfde0
removed alternative phase_shift computation inside accumulate function
sfmig Jan 30, 2023
d72ab48
changed phase shift to be computed inside the loop to avoid unexpecte…
sfmig Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions notebooks/notebook_fwd_jax_vs_gt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
import sys

sys.path.append("../")


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

import numpy as np
from jax.config import config
import pyssht as ssht
import s2fft as s2f

config.update("jax_enable_x64", True)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Compare All JAX methods to ground truth

# generate spherical harmonics (ground truth)
DEFAULT_SEED = 8966433580120847635
rn_gen = np.random.default_rng(DEFAULT_SEED) # modify to use JAX random key approach?

### input params
L_in = 6 # input L, redefined to 2*nside if sampling is healpix
spin = -2
reality = False
L_lower = 1
nside_healpix = 2

# list jax implementations
list_jax_str = [
"jax_vmap_double",
"jax_vmap_scan",
"jax_vmap_loop",
"jax_map_double",
"jax_map_scan",
]

# list of sampling approaches
list_sampling = ["mw", "mwss", "dh", "healpix"]

# Print inputs common to all sampling methods
print("-----------------------------------------")
print("Input params:")
print(
f"L_in = {L_in}, spin = {spin}, reality = {reality}, L_lower = {L_lower}, nside_healpix={nside_healpix}"
)
print("-----------------------------------------")

# All JAX methods
for m_str in list_jax_str:

for sampling in list_sampling:

# Groundtruth and starting point f
if sampling != "healpix":
# Set nside to None if not healpix
nside = None
L = L_in

# compute ground truth and starting point f
flm_gt = s2f.utils.generate_flm(rn_gen, L, L_lower, spin, reality=reality)
f = ssht.inverse(
s2f.samples.flm_2d_to_1d(flm_gt, L),
L,
Method=sampling.upper(),
Spin=spin,
Reality=False,
)

else:
# Set nside to healpix value and redefine L
nside = nside_healpix
L = 2 * nside # L redefined to double nside

# compute ground truth and starting point f
flm_gt0 = s2f.utils.generate_flm(rn_gen, L, L_lower, spin, reality=reality)
f = s2f.transform._inverse(
flm_gt0, L, spin, sampling=sampling, method="direct", nside=nside, reality=reality, L_lower=L_lower
)

# Compute numpy solution if sampling is healpix (used as GT)
flm_sov_fft_vec = s2f.transform._forward(
f,
L,
spin,
sampling,
method="sov_fft_vectorized",
nside=nside,
reality=reality,
L_lower=L_lower,
)

# JAX implementation
flm_sov_fft_vec_jax = s2f.transform._forward(
f,
L,
spin,
sampling,
method=m_str,
nside=nside,
reality=reality,
L_lower=L_lower,
)

# Compare to GT
if sampling == "healpix": # compare to numpy result rather than GT
print(
f"{m_str} vs numpy ({sampling}, L = {L}): {np.allclose(flm_sov_fft_vec, flm_sov_fft_vec_jax, atol=1e-14, rtol=1e-07)}"
)
else:
print(
f"{m_str} vs GT ({sampling}, L = {L}): {np.allclose(flm_gt, flm_sov_fft_vec_jax, atol=1e-14, rtol=1e-07)}" # rtol=1e-07 default in np.testing.assert_allclose
)

# %timeit s2f.transform._forward(f, L, spin, sampling, method=m_str, nside=nside, reality=reality, L_lower=L_lower)
# %timeit s2f.transform._forward(f, L, spin, sampling, method=m_str, nside=nside, reality=reality, L_lower=L_lower)
print("-----------------------------------------")


# %%
69 changes: 34 additions & 35 deletions s2fft/healpix_ffts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import numpy.fft as fft
import s2fft.samples as samples

import jax
import jax.numpy as jnp
import jax.numpy.fft as jfft

from functools import partial


def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray:
"""Folds higher frequency Fourier coefficients back onto lower frequency
Expand Down Expand Up @@ -35,41 +41,34 @@ def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray:
return ftm_slice


def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray:
def spectral_periodic_extension(fm: np.ndarray, L: int, numpy_module=np) -> np.ndarray:
"""Extends lower frequency Fourier coefficients onto higher frequency
coefficients, i.e. imposed periodicity in Fourier space.
coefficients, i.e. imposed periodicity in Fourier space.
Based on `spectral_periodic_extension`, modified to be JIT-compilable.

Args:
fm (np.ndarray): Slice of Fourier coefficients corresponding to ring at latitute t.

nphi (int): Total number of pixel space phi samples for latitude t.

L (int): Harmonic band-limit.

numpy_module: JAX's Numpy-like API or Numpy. Default Numpy.

Returns:
np.ndarray: Higher resolution set of periodic Fourier coefficients.
"""
assert nphi <= 2 * L

slice_start = L - nphi // 2
slice_stop = slice_start + nphi
fm_full = np.zeros(2 * L, dtype=np.complex128)
fm_full[slice_start:slice_stop] = fm

idx = 1
while slice_start - idx >= 0:
fm_full[slice_start - idx] = fm[-idx % nphi]
idx += 1
idx = 0
while slice_stop + idx < len(fm_full):
fm_full[slice_stop + idx] = fm[idx % nphi]
idx += 1

return fm_full


def healpix_fft(f: np.ndarray, L: int, nside: int) -> np.ndarray:
"""Computes the Forward Fast Fourier Transform with spectral back-projection
nphi = fm.shape[0]
return numpy_module.concatenate(
(
fm[-numpy_module.arange(L - nphi // 2, 0, -1) % nphi],
fm,
fm[numpy_module.arange(L - (nphi + 1) // 2) % nphi]
)
)


def healpix_fft(f: np.ndarray, L: int, nside: int, numpy_module=np) -> np.ndarray:
'''
Computes the Forward Fast Fourier Transform with spectral back-projection
in the polar regions to manually enforce Fourier periodicity.

Args:
Expand All @@ -79,24 +78,24 @@ def healpix_fft(f: np.ndarray, L: int, nside: int) -> np.ndarray:

nside (int): HEALPix Nside resolution parameter.

numpy_module: JAX's Numpy-like API or Numpy. Default Numpy.

Returns:
np.ndarray: Array of Fourier coefficients for all latitudes.
"""
assert L >= 2 * nside
'''

assert L >= 2 * nside
ntheta = samples.ntheta(L, "healpix", nside)
index = 0
ftm = np.zeros(samples.ftm_shape(L, "healpix", nside), dtype=np.complex128)
ntheta = ftm.shape[0]
ftm_rows = []
for t in range(ntheta):
nphi = samples.nphi_ring(t, nside)
fm_chunk = fft.fftshift(fft.fft(f[index : index + nphi], norm="backward"))
ftm[t] = (
fm_chunk
if nphi == 2 * L
else spectral_periodic_extension(fm_chunk, nphi, L)
fm_chunk = numpy_module.fft.fftshift(
numpy_module.fft.fft(f[index : index + nphi], norm="backward")
)
ftm_rows.append(spectral_periodic_extension(fm_chunk, L, numpy_module))
index += nphi
return ftm
return numpy_module.stack(ftm_rows)


def healpix_ifft(ftm: np.ndarray, L: int, nside: int) -> np.ndarray:
Expand Down
96 changes: 88 additions & 8 deletions s2fft/samples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

import jax.numpy as jnp
import jax.lax as lax

def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int:
r"""Number of :math:`\theta` samples for sampling scheme at specified resolution.
Expand Down Expand Up @@ -332,9 +333,50 @@ def phis_ring(t: int, nside: int) -> np.ndarray:
return p2phi_ring(t, p, nside)


def p2phi_ring(t: int, p: int, nside: int) -> np.ndarray:
def p2phi_ring(t: int, p: int, nside: int, numpy_module=np):
r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring.

See 'p2phi_ring_jax_lax' for an alternative JAX-only implementation using nested lax.cond's,
and 'p2phi_ring_np' for an alternative Numpy-only implementation.

Args:
t (int): :math:`\theta` index of ring.

p (int): :math:`\phi` index within ring.

nside (int): HEALPix Nside resolution parameter.

numpy_module: JAX's Numpy-like API or Numpy. Default Numpy.

Returns:
np.ndarray: :math:`\phi` angle.
"""

# shift per region (2 regions)
shift = numpy_module.where(
(t + 1 >= nside) & (t + 1 <= 3 * nside), (1 / 2) * ((t - nside + 2) % 2), 1 / 2
)

# factor per region (3 regions)
factor_reg_1 = numpy_module.where(
(t + 1 >= nside) & (t + 1 <= 3 * nside), numpy_module.pi / (2 * nside), 1
)
factor_reg_2 = numpy_module.where(t + 1 > 3 * nside, numpy_module.pi / (2 * (4 * nside - t - 1)), 1)
factor_reg_3 = numpy_module.where(
(factor_reg_1 == 1) & (factor_reg_2 == 1), numpy_module.pi / (2 * (t + 1)), 1
)
factor = (
factor_reg_1 * factor_reg_2 * factor_reg_3
)

return factor * (p + shift)

def p2phi_ring_np(t: int, p: int, nside: int):
r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring - Numpy-only implementation.

See 'p2phi_ring_jax_lax' for an alternative JAX implementation using nested lax.cond's, and 'p2phi_ring' for a
JAX/Numpy implementation using jnp.where/np.where.

Args:
t (int): :math:`\theta` index of ring.

Expand All @@ -350,11 +392,48 @@ def p2phi_ring(t: int, p: int, nside: int) -> np.ndarray:
if (t + 1 >= nside) & (t + 1 <= 3 * nside):
shift *= (t - nside + 2) % 2
factor = np.pi / (2 * nside)
return factor * (p + shift)
elif t + 1 > 3 * nside:
factor = np.pi / (2 * (4 * nside - t - 1))
else:
factor = np.pi / (2 * (t + 1))

return factor * (p + shift)

def p2phi_ring_jax_lax(t: int, p: int, nside: int) -> jnp.ndarray:
r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring - JAX-only implementation.

Uses nested lax.cond from JAX. See 'p2phi_ring' for an alternative JAX/Numpy implementation using
jnp.where/np.where, and 'p2phi_ring_np' for a Numpy-only implementation.

Args:
t (int): :math:`\theta` index of ring.

p (int): :math:`\phi` index within ring.

nside (int): HEALPix Nside resolution parameter.

Returns:
jnp.ndarray: :math:`\phi` angle.
"""

# using nested lax.cond
shift = lax.cond((t + 1 >= nside) & (t + 1 <= 3 * nside),
lambda t: (1 / 2) * ((t - nside + 2) % 2),
lambda t: 1/2,
t)

factor = lax.cond(
(t + 1 >= nside) & (t + 1 <= 3 * nside),
lambda x: jnp.pi / (2 * nside),
lambda x: lax.cond(
x + 1 > 3 * nside,
lambda x: jnp.pi / (2 * (4 * nside - x - 1)),
lambda x: jnp.pi / (2 * (x + 1)),
x),
t)



return factor * (p + shift)


Expand Down Expand Up @@ -417,8 +496,8 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray:


def ring_phase_shift_hp(
L: int, t: int, nside: int, forward: bool = False
) -> np.ndarray:
L: int, t: int, nside: int, forward: bool = False, numpy_module=np
):
r"""Generates a phase shift vector for HEALPix for a given :math:`\theta` ring.

Args:
Expand All @@ -431,13 +510,14 @@ def ring_phase_shift_hp(
forward (bool, optional): Whether to provide forward or inverse shift.
Defaults to False.

numpy_module: JAX's Numpy-like API or Numpy. Default Numpy.

Returns:
np.ndarray: Vector of phase shifts with shape :math:`[2L-1]`.
"""
phi_offset = p2phi_ring(t, 0, nside)
phi_offset = p2phi_ring(t, 0, nside, numpy_module)
sign = -1 if forward else 1
return np.exp(sign * 1j * np.arange(-L + 1, L) * phi_offset)

return numpy_module.exp(sign * 1j * numpy_module.arange(-L + 1, L) * phi_offset)

def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> tuple:
"""Shape of spherical signal.
Expand Down
Loading