diff --git a/notebooks/notebook_fwd_jax_vs_gt.py b/notebooks/notebook_fwd_jax_vs_gt.py new file mode 100644 index 00000000..8f5a778b --- /dev/null +++ b/notebooks/notebook_fwd_jax_vs_gt.py @@ -0,0 +1,121 @@ +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +import sys + +sys.path.append("../") + + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +import numpy as np +from jax.config import config +import pyssht as ssht +import s2fft as s2f + +config.update("jax_enable_x64", True) + +# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +# Compare All JAX methods to ground truth + +# generate spherical harmonics (ground truth) +DEFAULT_SEED = 8966433580120847635 +rn_gen = np.random.default_rng(DEFAULT_SEED) # modify to use JAX random key approach? + +### input params +L_in = 6 # input L, redefined to 2*nside if sampling is healpix +spin = -2 +reality = False +L_lower = 1 +nside_healpix = 2 + +# list jax implementations +list_jax_str = [ + "jax_vmap_double", + "jax_vmap_scan", + "jax_vmap_loop", + "jax_map_double", + "jax_map_scan", +] + +# list of sampling approaches +list_sampling = ["mw", "mwss", "dh", "healpix"] + +# Print inputs common to all sampling methods +print("-----------------------------------------") +print("Input params:") +print( + f"L_in = {L_in}, spin = {spin}, reality = {reality}, L_lower = {L_lower}, nside_healpix={nside_healpix}" +) +print("-----------------------------------------") + +# All JAX methods +for m_str in list_jax_str: + + for sampling in list_sampling: + + # Groundtruth and starting point f + if sampling != "healpix": + # Set nside to None if not healpix + nside = None + L = L_in + + # compute ground truth and starting point f + flm_gt = s2f.utils.generate_flm(rn_gen, L, L_lower, spin, reality=reality) + f = ssht.inverse( + s2f.samples.flm_2d_to_1d(flm_gt, L), + L, + Method=sampling.upper(), + Spin=spin, + Reality=False, + ) + + else: + # Set nside to healpix value and redefine L + nside = nside_healpix + L = 2 * nside # L redefined to double nside + + # compute ground truth and starting point f + flm_gt0 = s2f.utils.generate_flm(rn_gen, L, L_lower, spin, reality=reality) + f = s2f.transform._inverse( + flm_gt0, L, spin, sampling=sampling, method="direct", nside=nside, reality=reality, L_lower=L_lower + ) + + # Compute numpy solution if sampling is healpix (used as GT) + flm_sov_fft_vec = s2f.transform._forward( + f, + L, + spin, + sampling, + method="sov_fft_vectorized", + nside=nside, + reality=reality, + L_lower=L_lower, + ) + + # JAX implementation + flm_sov_fft_vec_jax = s2f.transform._forward( + f, + L, + spin, + sampling, + method=m_str, + nside=nside, + reality=reality, + L_lower=L_lower, + ) + + # Compare to GT + if sampling == "healpix": # compare to numpy result rather than GT + print( + f"{m_str} vs numpy ({sampling}, L = {L}): {np.allclose(flm_sov_fft_vec, flm_sov_fft_vec_jax, atol=1e-14, rtol=1e-07)}" + ) + else: + print( + f"{m_str} vs GT ({sampling}, L = {L}): {np.allclose(flm_gt, flm_sov_fft_vec_jax, atol=1e-14, rtol=1e-07)}" # rtol=1e-07 default in np.testing.assert_allclose + ) + + # %timeit s2f.transform._forward(f, L, spin, sampling, method=m_str, nside=nside, reality=reality, L_lower=L_lower) + # %timeit s2f.transform._forward(f, L, spin, sampling, method=m_str, nside=nside, reality=reality, L_lower=L_lower) + print("-----------------------------------------") + + +# %% diff --git a/s2fft/healpix_ffts.py b/s2fft/healpix_ffts.py index 25c3c053..8ad8ac51 100644 --- a/s2fft/healpix_ffts.py +++ b/s2fft/healpix_ffts.py @@ -2,6 +2,12 @@ import numpy.fft as fft import s2fft.samples as samples +import jax +import jax.numpy as jnp +import jax.numpy.fft as jfft + +from functools import partial + def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: """Folds higher frequency Fourier coefficients back onto lower frequency @@ -35,41 +41,34 @@ def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: return ftm_slice -def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: +def spectral_periodic_extension(fm: np.ndarray, L: int, numpy_module=np) -> np.ndarray: """Extends lower frequency Fourier coefficients onto higher frequency - coefficients, i.e. imposed periodicity in Fourier space. + coefficients, i.e. imposed periodicity in Fourier space. + Based on `spectral_periodic_extension`, modified to be JIT-compilable. Args: fm (np.ndarray): Slice of Fourier coefficients corresponding to ring at latitute t. - nphi (int): Total number of pixel space phi samples for latitude t. - L (int): Harmonic band-limit. + + numpy_module: JAX's Numpy-like API or Numpy. Default Numpy. Returns: np.ndarray: Higher resolution set of periodic Fourier coefficients. """ - assert nphi <= 2 * L - - slice_start = L - nphi // 2 - slice_stop = slice_start + nphi - fm_full = np.zeros(2 * L, dtype=np.complex128) - fm_full[slice_start:slice_stop] = fm - - idx = 1 - while slice_start - idx >= 0: - fm_full[slice_start - idx] = fm[-idx % nphi] - idx += 1 - idx = 0 - while slice_stop + idx < len(fm_full): - fm_full[slice_stop + idx] = fm[idx % nphi] - idx += 1 - - return fm_full - - -def healpix_fft(f: np.ndarray, L: int, nside: int) -> np.ndarray: - """Computes the Forward Fast Fourier Transform with spectral back-projection + nphi = fm.shape[0] + return numpy_module.concatenate( + ( + fm[-numpy_module.arange(L - nphi // 2, 0, -1) % nphi], + fm, + fm[numpy_module.arange(L - (nphi + 1) // 2) % nphi] + ) + ) + + +def healpix_fft(f: np.ndarray, L: int, nside: int, numpy_module=np) -> np.ndarray: + ''' + Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions to manually enforce Fourier periodicity. Args: @@ -79,24 +78,24 @@ def healpix_fft(f: np.ndarray, L: int, nside: int) -> np.ndarray: nside (int): HEALPix Nside resolution parameter. + numpy_module: JAX's Numpy-like API or Numpy. Default Numpy. + Returns: np.ndarray: Array of Fourier coefficients for all latitudes. - """ - assert L >= 2 * nside + ''' + assert L >= 2 * nside + ntheta = samples.ntheta(L, "healpix", nside) index = 0 - ftm = np.zeros(samples.ftm_shape(L, "healpix", nside), dtype=np.complex128) - ntheta = ftm.shape[0] + ftm_rows = [] for t in range(ntheta): nphi = samples.nphi_ring(t, nside) - fm_chunk = fft.fftshift(fft.fft(f[index : index + nphi], norm="backward")) - ftm[t] = ( - fm_chunk - if nphi == 2 * L - else spectral_periodic_extension(fm_chunk, nphi, L) + fm_chunk = numpy_module.fft.fftshift( + numpy_module.fft.fft(f[index : index + nphi], norm="backward") ) + ftm_rows.append(spectral_periodic_extension(fm_chunk, L, numpy_module)) index += nphi - return ftm + return numpy_module.stack(ftm_rows) def healpix_ifft(ftm: np.ndarray, L: int, nside: int) -> np.ndarray: diff --git a/s2fft/samples.py b/s2fft/samples.py index 2deeeed0..c48bbc77 100644 --- a/s2fft/samples.py +++ b/s2fft/samples.py @@ -1,5 +1,6 @@ import numpy as np - +import jax.numpy as jnp +import jax.lax as lax def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: r"""Number of :math:`\theta` samples for sampling scheme at specified resolution. @@ -332,9 +333,50 @@ def phis_ring(t: int, nside: int) -> np.ndarray: return p2phi_ring(t, p, nside) -def p2phi_ring(t: int, p: int, nside: int) -> np.ndarray: +def p2phi_ring(t: int, p: int, nside: int, numpy_module=np): r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring. + See 'p2phi_ring_jax_lax' for an alternative JAX-only implementation using nested lax.cond's, + and 'p2phi_ring_np' for an alternative Numpy-only implementation. + + Args: + t (int): :math:`\theta` index of ring. + + p (int): :math:`\phi` index within ring. + + nside (int): HEALPix Nside resolution parameter. + + numpy_module: JAX's Numpy-like API or Numpy. Default Numpy. + + Returns: + np.ndarray: :math:`\phi` angle. + """ + + # shift per region (2 regions) + shift = numpy_module.where( + (t + 1 >= nside) & (t + 1 <= 3 * nside), (1 / 2) * ((t - nside + 2) % 2), 1 / 2 + ) + + # factor per region (3 regions) + factor_reg_1 = numpy_module.where( + (t + 1 >= nside) & (t + 1 <= 3 * nside), numpy_module.pi / (2 * nside), 1 + ) + factor_reg_2 = numpy_module.where(t + 1 > 3 * nside, numpy_module.pi / (2 * (4 * nside - t - 1)), 1) + factor_reg_3 = numpy_module.where( + (factor_reg_1 == 1) & (factor_reg_2 == 1), numpy_module.pi / (2 * (t + 1)), 1 + ) + factor = ( + factor_reg_1 * factor_reg_2 * factor_reg_3 + ) + + return factor * (p + shift) + +def p2phi_ring_np(t: int, p: int, nside: int): + r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring - Numpy-only implementation. + + See 'p2phi_ring_jax_lax' for an alternative JAX implementation using nested lax.cond's, and 'p2phi_ring' for a + JAX/Numpy implementation using jnp.where/np.where. + Args: t (int): :math:`\theta` index of ring. @@ -350,11 +392,48 @@ def p2phi_ring(t: int, p: int, nside: int) -> np.ndarray: if (t + 1 >= nside) & (t + 1 <= 3 * nside): shift *= (t - nside + 2) % 2 factor = np.pi / (2 * nside) - return factor * (p + shift) elif t + 1 > 3 * nside: factor = np.pi / (2 * (4 * nside - t - 1)) else: factor = np.pi / (2 * (t + 1)) + + return factor * (p + shift) + +def p2phi_ring_jax_lax(t: int, p: int, nside: int) -> jnp.ndarray: + r"""Convert index to :math:`\phi` angle for HEALPix for given :math:`\theta` ring - JAX-only implementation. + + Uses nested lax.cond from JAX. See 'p2phi_ring' for an alternative JAX/Numpy implementation using + jnp.where/np.where, and 'p2phi_ring_np' for a Numpy-only implementation. + + Args: + t (int): :math:`\theta` index of ring. + + p (int): :math:`\phi` index within ring. + + nside (int): HEALPix Nside resolution parameter. + + Returns: + jnp.ndarray: :math:`\phi` angle. + """ + + # using nested lax.cond + shift = lax.cond((t + 1 >= nside) & (t + 1 <= 3 * nside), + lambda t: (1 / 2) * ((t - nside + 2) % 2), + lambda t: 1/2, + t) + + factor = lax.cond( + (t + 1 >= nside) & (t + 1 <= 3 * nside), + lambda x: jnp.pi / (2 * nside), + lambda x: lax.cond( + x + 1 > 3 * nside, + lambda x: jnp.pi / (2 * (4 * nside - x - 1)), + lambda x: jnp.pi / (2 * (x + 1)), + x), + t) + + + return factor * (p + shift) @@ -417,8 +496,8 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: def ring_phase_shift_hp( - L: int, t: int, nside: int, forward: bool = False -) -> np.ndarray: + L: int, t: int, nside: int, forward: bool = False, numpy_module=np +): r"""Generates a phase shift vector for HEALPix for a given :math:`\theta` ring. Args: @@ -431,13 +510,14 @@ def ring_phase_shift_hp( forward (bool, optional): Whether to provide forward or inverse shift. Defaults to False. + numpy_module: JAX's Numpy-like API or Numpy. Default Numpy. + Returns: np.ndarray: Vector of phase shifts with shape :math:`[2L-1]`. """ - phi_offset = p2phi_ring(t, 0, nside) + phi_offset = p2phi_ring(t, 0, nside, numpy_module) sign = -1 if forward else 1 - return np.exp(sign * 1j * np.arange(-L + 1, L) * phi_offset) - + return numpy_module.exp(sign * 1j * numpy_module.arange(-L + 1, L) * phi_offset) def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> tuple: """Shape of spherical signal. diff --git a/s2fft/transform.py b/s2fft/transform.py index fdc84c0a..ea7d313a 100644 --- a/s2fft/transform.py +++ b/s2fft/transform.py @@ -1,11 +1,18 @@ +from functools import partial +from warnings import warn + +import jax +import jax.numpy as jnp +import jax.numpy.fft as jfft import numpy as np import numpy.fft as fft -from warnings import warn -import s2fft.samples as samples +from jax import jit + +import s2fft.healpix_ffts as hp import s2fft.quadrature as quadrature import s2fft.resampling as resampling +import s2fft.samples as samples import s2fft.wigner as wigner -import s2fft.healpix_ffts as hp def inverse( @@ -200,8 +207,9 @@ def _forward( {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". method (str, optional): Harmonic transform algorithm. Supported algorithms include - {"direct", "sov", "sov_fft", "sov_fft_vectorized"}. Defaults to - "sov_fft_vectorized". + {"direct", "sov", "sov_fft", "sov_fft_vectorized"} and a set of exploratory JAX + implementations: {"jax_vmap_double", "jax_vmap_scan", "jax_vmap_loop", "jax_map_double", + "jax_map_scan"}. Defaults to "sov_fft_vectorized". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -246,12 +254,26 @@ def _forward( # since accounted for already in periodic extension and upsampling. weights = quadrature.quad_weights_transform(L, sampling, 0, nside) - transform_methods = { - "direct": _compute_forward_direct, - "sov": _compute_forward_sov, - "sov_fft": _compute_forward_sov_fft, - "sov_fft_vectorized": _compute_forward_sov_fft_vectorized, - } + list_jax_methods = [ + "jax_vmap_double", + "jax_vmap_scan", + "jax_vmap_loop", + "jax_map_double", + "jax_map_scan", + ] + transform_methods = dict( + { + "direct": _compute_forward_direct, + "sov": _compute_forward_sov, + "sov_fft": _compute_forward_sov_fft, + "sov_fft_vectorized": _compute_forward_sov_fft_vectorized, + }, + **{ + jx_str: partial(_compute_forward_jax, jax_method=jx_str) + for jx_str in list_jax_methods + } + ) + return transform_methods[method]( f, L, @@ -937,6 +959,7 @@ def _compute_forward_sov_fft_vectorized( ftm = np.zeros_like(f).astype(np.complex128) m_offset = 1 if sampling in ["mwss", "healpix"] else 0 + if reality: m_conj = (-1) ** (np.arange(1, L) % 2) @@ -955,6 +978,8 @@ def _compute_forward_sov_fft_vectorized( else: ftm = fft.fftshift(fft.fft(f, axis=1, norm="backward"), axes=1) + m_start_ind = L - 1 if reality else 0 + for t, theta in enumerate(thetas): phase_shift = ( @@ -969,7 +994,6 @@ def _compute_forward_sov_fft_vectorized( elfactor = np.sqrt((2 * el + 1) / (4 * np.pi)) - m_start_ind = L - 1 if reality else 0 flm[el, m_start_ind:] += ( weights[t] * elfactor @@ -987,3 +1011,616 @@ def _compute_forward_sov_fft_vectorized( flm *= (-1) ** spin return flm + + +@partial(jit, static_argnums=(1, 2, 3, 6, 7, 8, 9)) +def _compute_forward_jax( + f: np.ndarray, + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + jax_method: str, +): + r"""Compute forward spherical harmonic transform using a specified JAX method. + + Args: + f (np.ndarray): Signal on the sphere. + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. + + jax_method (str): Harmonic transform algorithm based on JAX functions. + They are all based on "sov_fft_vectorized", a vectorized function to + compute forward spherical harmonic transform by separation of variables + with a manual Fourier transform. Supported algorithms include + {"jax_vmap_double", "jax_vmap_scan", "jax_vmap_loop", + "jax_map_double", "jax_map_scan""}. + + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # m offset + m_offset = 1 if sampling in ["mwss", "healpix"] else 0 + + # ftm array + if sampling.lower() == "healpix": + ftm = hp.healpix_fft(f, L, nside, jnp) + else: + if reality: + t = jfft.rfft( + jnp.real(f), + axis=1, + norm="backward", + ) + if m_offset != 0: + t = t[:, :-1] + + # build ftm by concatenating t with a zero array + ftm = jnp.hstack( + (jnp.zeros((f.shape[0], L - 1 + m_offset)).astype(jnp.complex128), t), + ) + else: + ftm = jfft.fftshift(jfft.fft(f, axis=1, norm="backward"), axes=1) + + # els, elfactor + els = jnp.arange(max(L_lower, abs(spin)), L) + elfactors = jnp.sqrt((2 * els + 1) / (4 * jnp.pi)) + + # m_start_ind + m_start_ind = L - 1 if reality else 0 + + # Compute flm using the selected approach---change name to fn? + flm_methods = { + "jax_vmap_double": _compute_flm_vmap_double, + "jax_vmap_scan": _compute_flm_vmap_scan, + "jax_vmap_loop": _compute_flm_vmap_loop, + "jax_map_double": _compute_flm_map_double, + "jax_map_scan": _compute_flm_map_scan, + } + flm = flm_methods[jax_method]( + L, + spin, + sampling, + thetas, + weights, + nside, + reality, + L_lower, + m_offset, + ftm, + els, + elfactors, + m_start_ind, + ) + + # Apply spin + flm *= (-1) ** spin + + # Mask after pad (to set spurious results from wigner.turok_jax.compute_slice to zero) + upper_diag = jnp.triu(jnp.ones_like(flm, dtype=bool).T, k=-(L - 1)).T + mask = upper_diag * jnp.fliplr(upper_diag) + flm *= mask + + # if reality=True: fill the first half of the columns w/ conjugate symmetric values + if reality: + # m conj + m_conj = (-1) ** (jnp.arange(1, L) % 2) + + flm = flm.at[:, :m_start_ind].set( + jnp.flip(m_conj * jnp.conj(flm[:, m_start_ind + 1 :]), axis=-1) + ) + + return flm + + +def _compute_flm_vmap_double( + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + m_offset: int, + ftm: np.ndarray, + els: np.ndarray, + elfactors: np.ndarray, + m_start_ind: int, +): + r"""Compute forward spherical harmonic transform using the 'double vmap' JAX method. + + We use the JAX `vmap` function to sweep across :math:`\theta` and :math:`\ell` (`el`). + Specifically, we compute the complete Wigner-d matrix (`dl`) as a 3D array by vmapping + the Turok & Bucher recursion twice, first along :math:`\theta` and then along :math:`\ell` (`el`). + The spherical harmonic coefficients are computed as the 2D array `flm` that results + from summing along the last dimension (:math:`\theta`) of the product 3D array. + + Args: + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. + + m_offset (int): set to 1 if sampling in ["mwss", "healpix"], + otherwise set to 0. + + ftm (np.ndarray): fast Fourier transform matrix. + + els (np.ndarray): vector of `el` values + + elfactors (np.ndarray): vector of `el` factors, computed as + :math:`\sqrt{\frac{2el + 1}{4\pi}}` + + m_start_ind (int): column offset if reality = True (set to L-1) + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # phase shifts per theta + if sampling.lower() != "healpix": + phase_shifts = jnp.array([[1.0]]) + else: + phase_shifts = jax.vmap( + samples.ring_phase_shift_hp, + in_axes=(None, 0, None, None, None), + out_axes=-1, # theta along last dimension + )(L, jnp.arange(len(thetas)), nside, True, jnp) + + # dl vmapped function (double vmap along theta and el) + dl_fn = jax.vmap( + jax.vmap( + wigner.turok_jax.compute_slice, + in_axes=(0, None, None, None, None), + out_axes=-1, # theta along last dimension + ), + in_axes=(None, 0, None, None, None), + out_axes=0, # el along first dimension + ) + + # flm + flm = jnp.zeros( + (*samples.flm_shape(L), len(thetas)), # 3D array (L, 2L-1, ntheta) + dtype=jnp.complex128, + ) + flm = ( + flm.at[max(L_lower, abs(spin)) :, m_start_ind:, :] + .set( + weights # [None,None,:] + * elfactors[:, None, None] + * dl_fn(thetas, els, L, -spin, reality)[:, m_start_ind:, :] + * ftm[:, m_start_ind + m_offset : 2 * L - 1 + m_offset, None].T + * phase_shifts[None, :, :] + ) + .sum(axis=-1) + ) + + return flm + + +def _compute_flm_vmap_scan( + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + m_offset: int, + ftm: np.ndarray, + els: np.ndarray, + elfactors: np.ndarray, + m_start_ind: int, +): + r"""Compute forward spherical harmonic transform using the 'vmap + scan' JAX method. + + We use the JAX `vmap` function to sweep across :math:`\ell` (`el`) and the JAX `lax.scan` + function to sweep across :math:`\theta`. + Specifically, we compute the complete Wigner-d matrix (`dl`) as a 2D array defined for each + :math:`\theta` value, by vmapping the Turok & Bucher recursion along :math:`\ell` (`el`). + The spherical harmonic coefficients are computed as the 2D array `flm` that results + from scanning along :math:`\theta` values and accumulating the result of the product. + + Args: + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. + + m_offset (int): set to 1 if sampling in ["mwss", "healpix"], + otherwise set to 0. + + ftm (np.ndarray): fast Fourier transform matrix. + + els (np.ndarray): vector of `el` values + + elfactors (np.ndarray): vector of `el` factors, computed as + :math:`\sqrt{\frac{2el + 1}{4\pi}}` + + m_start_ind (int): column offset if reality = True (set to L-1) + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # phase shift + if sampling.lower() != "healpix": + phase_shifts = jnp.ones_like(thetas) + else: + phase_shifts = jax.vmap( + samples.ring_phase_shift_hp, + in_axes=(None, 0, None, None, None), + out_axes=0, # theta along first dimension + )(L, jnp.arange(len(thetas)), nside, True, jnp) + + # dl vmapped function (single vmap along el) + dl_vmapped = jax.vmap( + wigner.turok_jax.compute_slice, + in_axes=(None, 0, None, None, None), # vmap along el + out_axes=0, # el along first dimension + ) + + # flm (lax.scan over theta) + flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128) # (L, 2L-1) + + def accumulate(flm_carry, theta_weight_ftm_phase_shift): + theta, weight, ftm_slice, phase_shift = theta_weight_ftm_phase_shift + flm_carry = flm_carry.at[max(L_lower, abs(spin)) :, m_start_ind:].add( + weight + * elfactors[:, None] + * dl_vmapped(theta, els, L, -spin, reality)[:, m_start_ind:] + * ftm_slice[None, m_start_ind + m_offset : 2 * L - 1 + m_offset] + * phase_shift + ) + return flm_carry, None + + flm, _ = jax.lax.scan(accumulate, flm, (thetas, weights, ftm, phase_shifts)) + + return flm + + +def _compute_flm_vmap_loop( + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + m_offset: int, + ftm: np.ndarray, + els: np.ndarray, + elfactors: np.ndarray, + m_start_ind: int, +): + r"""Compute forward spherical harmonic transform using the 'vmap + loop' JAX method. + + We use the JAX `vmap` function to sweep across :math:`\ell` (`el`) and a regular Python + loop to sweep across :math:`\theta`. + Specifically, we compute the complete Wigner-d matrix (`dl`) as a 2D array defined for each + :math:`\theta` value, by vmapping the Turok & Bucher recursion along :math:`\ell` (`el`). + The spherical harmonic coefficients are computed as the 2D array `flm` that results + from looping along :math:`\theta` values and accumulating the result of the product. + + Args: + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. + + m_offset (int): set to 1 if sampling in ["mwss", "healpix"], + otherwise set to 0. + + ftm (np.ndarray): fast Fourier transform matrix. + + els (np.ndarray): vector of `el` values + + elfactors (np.ndarray): vector of `el` factors, computed as + :math:`\sqrt{\frac{2el + 1}{4\pi}}` + + m_start_ind (int): column offset if reality = True (set to L-1) + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # dl vmapped function (single vmap along el) + dl_vmapped = jax.vmap( + wigner.turok_jax.compute_slice, + in_axes=(None, 0, None, None, None), # vmap along el + out_axes=0, # el along first dimension + ) + + # flm (Python loop over theta) + flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128) # (L, 2L-1) + for ti, theta in enumerate(thetas): + flm = flm.at[max(L_lower, abs(spin)) :, m_start_ind:].add( + weights[ti] + * elfactors[:, None] + * dl_vmapped(theta, els, L, -spin, reality)[:, m_start_ind:] + * ftm[ti, m_start_ind + m_offset : 2 * L - 1 + m_offset][:, None].T + * ( + samples.ring_phase_shift_hp(L, ti, nside, True, jnp) + if sampling.lower() == "healpix" + else 1.0 + ) + ) + + return flm + + +def _compute_flm_map_double( + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + m_offset: int, + ftm: np.ndarray, + els: np.ndarray, + elfactors: np.ndarray, + m_start_ind: int, +): + r"""Compute forward spherical harmonic transform using the 'double map' JAX method. + + We use the JAX `lax.map` function to sweep across :math:`\theta` and :math:`\ell` (`el`). + Specifically, we compute the complete Wigner-d matrix (`dl`) as a 3D array by mapping + the Turok & Bucher recursion twice, first along :math:`\theta` and then along :math:`\ell` (`el`). + The spherical harmonic coefficients are computed as the 2D array `flm` that results + from summing along the last dimension (:math:`\theta`) of the product 3D array. + + + Args: + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. + + m_offset (int): set to 1 if sampling in ["mwss", "healpix"], + otherwise set to 0. + + ftm (np.ndarray): fast Fourier transform matrix. + + els (np.ndarray): vector of `el` values + + elfactors (np.ndarray): vector of `el` factors, computed as + :math:`\sqrt{\frac{2el + 1}{4\pi}}` + + m_start_ind (int): column offset if reality = True (set to L-1) + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # phase shifts per theta + if sampling.lower() != "healpix": + phase_shifts = jnp.array([[1.0]]) + else: + phase_shifts = jax.vmap( + samples.ring_phase_shift_hp, + in_axes=(None, 0, None, None, None), + out_axes=-1, # theta along last dimension + )(L, jnp.arange(len(thetas)), nside, True, jnp) + + # dl array (double map along theta and el) + dl = jax.lax.map( + lambda el: jax.lax.map( + lambda theta: wigner.turok_jax.compute_slice(theta, el, L, -spin), thetas + ).T, + els, + ) + + # flm + flm = jnp.zeros( + (*samples.flm_shape(L), len(thetas)), # 3D array (L, 2L-1, ntheta) + dtype=jnp.complex128, + ) + flm = ( + flm.at[max(L_lower, abs(spin)) :, m_start_ind:, :] + .set( + weights + * elfactors[:, None, None] + * dl[:, m_start_ind:, :] + * ftm[:, m_start_ind + m_offset : 2 * L - 1 + m_offset, None].T + * phase_shifts[None, :, :] + ) + .sum(axis=-1) + ) + + return flm + + +def _compute_flm_map_scan( + L: int, + spin: int, + sampling: str, + thetas: np.ndarray, + weights: np.ndarray, + nside: int, + reality: bool, + L_lower: int, + m_offset: int, + ftm: np.ndarray, + els: np.ndarray, + elfactors: np.ndarray, + m_start_ind: int, +): + r"""Compute forward spherical harmonic transform using the 'map + scan' JAX method. + + We use the JAX `map` function to sweep across :math:`\ell` (`el`) and the JAX `lax.scan` + function to sweep across :math:`\theta`. + Specifically, we compute the complete Wigner-d matrix (`dl`) as a 2D array defined for each + :math:`\theta` value, by mapping the Turok & Bucher recursion along :math:`\ell` (`el`). + The spherical harmonic coefficients are computed as the 2D array `flm` that results + from scanning along :math:`\theta` values and accumulating the result of the product. + + Args: + + L (int): Harmonic band-limit. + + spin (int): Harmonic spin. + + sampling (str): Sampling scheme. Supported sampling schemes include + {"mw", "mwss", "dh", "healpix"}. + + thetas (np.ndarray): Vector of sample positions in :math:`\theta` on the sphere. + + weights (np.ndarray): Vector of quadrature weights on the sphere. + + nside (int): HEALPix Nside resolution parameter. Only required + if sampling="healpix". Defaults to None. + + reality (bool): Whether the signal on the sphere is real. If so, + conjugate symmetry is exploited to reduce computational costs. + + L_lower (int): Harmonic lower-bound. Transform will only be computed + for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. + + m_offset (int): set to 1 if sampling in ["mwss", "healpix"], + otherwise set to 0. + + ftm (np.ndarray): fast Fourier transform matrix. + + els (np.ndarray): vector of `el` values + + elfactors (np.ndarray): vector of `el` factors, computed as + :math:`\sqrt{\frac{2el + 1}{4\pi}}` + + m_start_ind (int): column offset if reality = True (set to L-1) + + Returns: + np.ndarray: Spherical harmonic coefficients. + """ + + # phase shift + if sampling.lower() != "healpix": + phase_shifts = jnp.ones_like(thetas) + else: + phase_shifts = jax.vmap( + samples.ring_phase_shift_hp, + in_axes=(None, 0, None, None, None), + out_axes=0, # theta along first dimension + )(L, jnp.arange(len(thetas)), nside, True, jnp) + + # flm (lax.scan over theta) + ts = jnp.arange(len(thetas)) + flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128) # (L, 2L-1) + + def accumulate(flm_carry, t_theta_weight_ftm_phase_shift): + t, theta, weight, ftm_slice, phase_shift = t_theta_weight_ftm_phase_shift + + # dl array (single map across el) + dl = jax.lax.map( + lambda el: wigner.turok_jax.compute_slice(theta, el, L, -spin), els + ) + + flm_carry = flm_carry.at[max(L_lower, abs(spin)) :, m_start_ind:].add( + weight + * elfactors[:, None] + * dl[:, m_start_ind:] + * ftm_slice[None, m_start_ind + m_offset : 2 * L - 1 + m_offset] + * phase_shift + ) + return flm_carry, None + + flm, _ = jax.lax.scan( + accumulate, + flm, + (ts, thetas, weights, ftm, phase_shifts), + ) + + return flm diff --git a/s2fft/wigner/turok_jax.py b/s2fft/wigner/turok_jax.py index af094477..c5b7a3d2 100644 --- a/s2fft/wigner/turok_jax.py +++ b/s2fft/wigner/turok_jax.py @@ -3,10 +3,12 @@ from jax import jit import jax.numpy as jnp from functools import partial +from warnings import warn -@partial(jit, static_argnums=(2, 3)) -def compute_slice(beta: float, el: int, L: int, mm: int) -> jnp.ndarray: +@partial(jit, static_argnums=(2, 3, 4)) +def compute_slice(beta: float, el: int, L: int, mm: int, +positive_m_only: bool = False) -> jnp.ndarray: r"""Compute a particular slice :math:`m^{\prime}`, denoted `mm`, of the complete Wigner-d matrix at polar angle :math:`\beta` using Turok & Bucher recursion. @@ -35,12 +37,39 @@ def compute_slice(beta: float, el: int, L: int, mm: int) -> jnp.ndarray: mm (int): Harmonic order at which to slice the matrix. + positive_m_only (bool, optional): Compute Wigner-d matrix for slice at m greater + than zero only. Defaults to False. + + Whether to exploit conjugate symmetry. By construction + this only leads to significant improvement for mm = 0. Defaults to False. + + Raises: + + Warning: If positive_m_only is true but mm not 0. + Returns: jnp.ndarray: Wigner-d matrix mm slice of dimension [2L-1]. """ - dl = jnp.zeros(2 * L - 1, dtype=jnp.float64) + if positive_m_only and mm != 0: + positive_m_only = False + + # Commenting the following warning for now, as its output is probably lost in the compiled version + # (side-effects don't get traced, see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions) + # + # warn( + # "Reality acceleration only supports spin 0 fields. " + # + "Defering to complex transform." + # ) + + dl = jnp.zeros( + 2 * L - 1, + dtype=jnp.float64) + dl = lax.cond( - jnp.abs(beta) < 1e-10, lambda x: _north_pole(x, el, L, mm), lambda x: x, dl + jnp.abs(beta) < 1e-10, + lambda x: _north_pole(x, el, L, mm), + lambda x: x, + dl ) dl = lax.cond( jnp.abs(beta - jnp.pi) < 1e-10, @@ -48,20 +77,24 @@ def compute_slice(beta: float, el: int, L: int, mm: int) -> jnp.ndarray: lambda x: x, dl, ) - dl = lax.cond(el == 0, lambda x: _el0(x, L), lambda x: x, dl) + dl = lax.cond( + el == 0, + lambda x: _el0(x, L), + lambda x: x, + dl + ) dl = lax.cond( jnp.any(dl), lambda x: x, - lambda x: _compute_quarter_slice(x, beta, el, L, mm), + lambda x: _compute_quarter_slice(x, beta, el, L, mm, positive_m_only), dl, ) - return reindex(dl, el, L, mm) -@partial(jit, static_argnums=(3, 4)) +@partial(jit, static_argnums=(3, 4,5)) def _compute_quarter_slice( - dl: jnp.array, beta: float, el: int, L: int, mm: int + dl: jnp.array, beta: float, el: int, L: int, mm: int, positive_m_only: bool = False ) -> jnp.ndarray: r"""Compute a single slice at :math:`m^{\prime}` of the Wigner-d matrix evaluated at :math:`\beta`. @@ -135,50 +168,50 @@ def log_first_row_iteration(log_first_row_i_minus_1, i): # Static array of indices for first dimension of dl array indices = jnp.arange(2 * L - 1) - for i in range(2): - sgn = (-1) ** (i) - - # Initialise the vector - dl = dl.at[lims[i]].set(1.0) - lamb = ((el + 1) * omc - half_slices[i] + c) / s - dl = dl.at[lims[i] + sgn].set(lamb * dl[lims[i]] * cpi[0]) - - def renorm_iteration(m, dl_lrenorm): - dl, lrenorm = dl_lrenorm - lamb = ((el + 1) * omc - half_slices[i] + m * c) / s - dl = dl.at[lims[i] + sgn * m].set( - lamb * cpi[m - 1] * dl[lims[i] + sgn * (m - 1)] - - cp2[m - 1] * dl[lims[i] + sgn * (m - 2)] - ) - condition = dl[lims[i] + sgn * m] > big_const - lrenorm = lax.cond( - condition, lambda x: x.at[i].add(-lbig), lambda x: x, lrenorm - ) - dl = lax.cond( - condition, - # multiply first m elements (if i == 0) or last m elements (if i == 1) - # of dl array by bigi - use jnp.where rather than directly updating - # array using 'in-place' update such as - # dl.at[lims[i]:lims[i] + sgn * (m + 1):sgn].multiply(bigi) - # to avoid non-static array slice (due to m dependence) that will raise - # an IndexError exception when used with lax.fori_loop - lambda x: jnp.where((indices < (m + 1))[::sgn], bigi * x, x), - lambda x: x, - dl - ) - return dl, lrenorm - - dl, lrenorm = lax.fori_loop(2, L, renorm_iteration, (dl, lrenorm)) - - # Apply renormalisation - renorm = sign[i] * jnp.exp(log_first_row[half_slices[i] - 1] - lrenorm[i]) - - if i == 0: - dl = dl.at[: L - 1].multiply(renorm) - - if i == 1: - dl = dl.at[-em].multiply((-1) ** ((mm - em + el + 1) % 2) * renorm) - + for i in ([1] if positive_m_only else range(2)): + + sgn = (-1) ** (i) + + # Initialise the vector + dl = dl.at[lims[i]].set(1.0) + lamb = ((el + 1) * omc - half_slices[i] + c) / s + dl = dl.at[lims[i] + sgn].set(lamb * dl[lims[i]] * cpi[0]) + + def renorm_iteration(m, dl_lrenorm): + dl, lrenorm = dl_lrenorm + lamb = ((el + 1) * omc - half_slices[i] + m * c) / s + dl = dl.at[lims[i] + sgn * m].set( + lamb * cpi[m - 1] * dl[lims[i] + sgn * (m - 1)] + - cp2[m - 1] * dl[lims[i] + sgn * (m - 2)] + ) + condition = dl[lims[i] + sgn * m] > big_const + lrenorm = lax.cond( + condition, lambda x: x.at[i].add(-lbig), lambda x: x, lrenorm + ) + dl = lax.cond( + condition, + # multiply first m elements (if i == 0) or last m elements (if i == 1) + # of dl array by bigi - use jnp.where rather than directly updating + # array using 'in-place' update such as + # dl.at[lims[i]:lims[i] + sgn * (m + 1):sgn].multiply(bigi) + # to avoid non-static array slice (due to m dependence) that will raise + # an IndexError exception when used with lax.fori_loop + lambda x: jnp.where((indices < (m + 1))[::sgn], bigi * x, x), + lambda x: x, + dl + ) + return dl, lrenorm + + dl, lrenorm = lax.fori_loop(2, L, renorm_iteration, (dl, lrenorm)) + + # Apply renormalisation + renorm = sign[i] * jnp.exp(log_first_row[half_slices[i] - 1] - lrenorm[i]) + + if i == 0: + dl = dl.at[: L - 1].multiply(renorm) + + if i == 1: + dl = dl.at[-em].multiply((-1) ** ((mm - em + el + 1) % 2) * renorm) return jnp.nan_to_num(dl, neginf=0, posinf=0) diff --git a/tests/test_transform.py b/tests/test_transform.py index 7ea6172b..014a1b8c 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,9 +1,12 @@ -import pytest +import healpy as hp import numpy as np -import s2fft as s2f import pyssht as ssht -import healpy as hp +import pytest +from jax.config import config + +import s2fft as s2f +config.update("jax_enable_x64", True) L_to_test = [6, 7, 8] L_lower_to_test = [0, 1, 2] @@ -11,7 +14,19 @@ nside_to_test = [2, 4, 8] L_to_nside_ratio = [2, 3] sampling_to_test = ["mw", "mwss", "dh"] -method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"] +method_to_test = [ + "direct", + "sov", + "sov_fft", + "sov_fft_vectorized", +] +method_to_test_forward_only = [ + "jax_vmap_double", + "jax_vmap_scan", + "jax_vmap_loop", + "jax_map_double", + "jax_map_scan", +] reality_to_test = [False, True] @@ -74,7 +89,7 @@ def test_transform_inverse_healpix( @pytest.mark.parametrize("L_lower", L_lower_to_test) @pytest.mark.parametrize("spin", spin_to_test) @pytest.mark.parametrize("sampling", sampling_to_test) -@pytest.mark.parametrize("method", method_to_test) +@pytest.mark.parametrize("method", method_to_test + method_to_test_forward_only) @pytest.mark.parametrize("reality", reality_to_test) def test_transform_forward( flm_generator, @@ -104,7 +119,7 @@ def test_transform_forward( @pytest.mark.parametrize("nside", nside_to_test) @pytest.mark.parametrize("ratio", L_to_nside_ratio) -@pytest.mark.parametrize("method", method_to_test) +@pytest.mark.parametrize("method", method_to_test) # + method_to_test_forward_only) @pytest.mark.parametrize("reality", reality_to_test) def test_transform_forward_healpix( flm_generator, @@ -115,10 +130,17 @@ def test_transform_forward_healpix( ): sampling = "healpix" L = ratio * nside - flm = flm_generator(L=L, reality=True) - f = s2f.transform._inverse( - flm, L, sampling=sampling, method=method, nside=nside - ) + flm = flm_generator(L=L, reality=True) # should this be reality=reality? + + if method in ["direct", "sov", "sov_fft", "sov_fft_vectorized"]: + f = s2f.transform._inverse( + flm, L, sampling=sampling, method=method, nside=nside + ) + # use 'direct' for JAX approaches + else: + f = s2f.transform._inverse( + flm, L, sampling=sampling, method="direct", nside=nside + ) flm_direct = s2f.transform._forward( f, L, sampling=sampling, method=method, nside=nside, reality=reality @@ -130,6 +152,7 @@ def test_transform_forward_healpix( np.testing.assert_allclose(flm_direct_hp, flm_check, atol=1e-14) + @pytest.mark.parametrize("nside", nside_to_test) def test_healpix_nside_to_L_exceptions(flm_generator, nside: int): sampling = "healpix"