From 8db5716ae597954e123ac27b66a994d543bf2ef7 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 4 Apr 2023 16:32:18 +0100 Subject: [PATCH 1/2] add interfacing for s2fft precompute transforms --- s2wav/transforms/__init__.py | 3 +- s2wav/transforms/jax_scattering.py | 274 ------------------- s2wav/transforms/jax_wavelets.py | 12 +- s2wav/transforms/jax_wavelets_precompute.py | 279 ++++++++++++++++++++ s2wav/transforms/numpy_scattering.py | 200 -------------- tests/test_wavelets_precompute.py | 151 +++++++++++ 6 files changed, 434 insertions(+), 485 deletions(-) delete mode 100644 s2wav/transforms/jax_scattering.py create mode 100644 s2wav/transforms/jax_wavelets_precompute.py delete mode 100644 s2wav/transforms/numpy_scattering.py create mode 100644 tests/test_wavelets_precompute.py diff --git a/s2wav/transforms/__init__.py b/s2wav/transforms/__init__.py index 7130eed..b0d7fe2 100644 --- a/s2wav/transforms/__init__.py +++ b/s2wav/transforms/__init__.py @@ -1,4 +1,3 @@ from . import numpy_wavelets -from . import numpy_scattering from . import jax_wavelets -from . import jax_scattering +from . import jax_wavelets_precompute \ No newline at end of file diff --git a/s2wav/transforms/jax_scattering.py b/s2wav/transforms/jax_scattering.py deleted file mode 100644 index e13f882..0000000 --- a/s2wav/transforms/jax_scattering.py +++ /dev/null @@ -1,274 +0,0 @@ -from jax import config - -config.update("jax_enable_x64", True) - -import jax.numpy as jnp -import jax.lax as lax -from typing import List, Tuple -from s2wav.transforms import jax_wavelets as wavelets -from s2wav.utils import shapes -import s2fft - - -def scatter( - f: jnp.ndarray, - L: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - nlayers: int = None, - sampling: str = "mw", - nside: int = None, - reality: bool = False, - filters: Tuple[jnp.ndarray] = None, - multiresolution: bool = False, - spmd: bool = False, - precomps: List[List[jnp.ndarray]] = None, -) -> List[jnp.ndarray]: - r"""Computes the scattering transform for descending by one nodes alone. - - Following equations outlined in section 3.2 of [1], recursively compute wavelet - transform each time passed through the activation function (in this case the absolute - value, 'modulus' operator) and store the scaling coefficients. - - Args: - f (np.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. - - L (int): Harmonic bandlimit. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor - between consecutive wavelet scales. Note that :math:`\lambda = 2` indicates - dyadic wavelets. Defaults to 2. - - nlayers (int, optional): Total number of scattering layers. Defaults to None, in - which case all paths which descend by 1 are included. - - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", - "healpix"}. Defaults to "mw". - - nside (int, optional): HEALPix Nside resolution parameter. Only required if - sampling="healpix". Defaults to None. - - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits - conjugate symmetry of harmonic coefficients. Defaults to False. - - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most - of length :math:`L^2`, which is a minimal memory overhead. - - - Raises: - ValueError: Number of layers is larger than the number of available wavelet scales. - - NotImplementedError: Filters not provided, and functionality to compute these in - JAX is not yet implemented. - - Returns: - List[np.ndarray]: List of scattering coefficients. Dimensionality of each scattering - coefficien will depend on the selection of hyperparameters. In the most - typical case (J_min = 0), each scattering coefficient is a single scalar value. - - Notes: - [1] McEwen et al, Scattering networks on the sphere for scalable and - rotationally equivariant spherical CNNs, ICLR 2022. - """ - if precomps == None: - precomps = wavelets.generate_wigner_precomputes( - L, N, J_min, lam, sampling, nside, False, reality, multiresolution - ) - - scattering_coefficients = [] - J = shapes.j_max(L, lam) - - if nlayers is None: - nlayers = J - J_min - if nlayers > J - J_min: - raise ValueError( - f"Number of scattering layers {nlayers} is larger than the number of available wavelet scales {J-J_min}." - ) - if filters == None: - raise ValueError("Automatic filter computation not yet implemented!") - - # Weight filters a priori - wav_lm = jnp.einsum( - "jln, l->jln", - jnp.conj(filters[0]), - 8 * jnp.pi**2 / (2 * jnp.arange(L) + 1), - optimize=True, - ) - scal_l = filters[1] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(L) + 1)) - - # Perform the first wavelet transform for all scales and directions - f_wav, f_scal = wavelets.analysis( - f, - L, - N, - J_min, - lam, - sampling=sampling, - nside=nside, - reality=reality, - multiresolution=multiresolution, - filters=filters, - spmd=spmd, - precomps=precomps, - scattering=True, - ) - scattering_coefficients.append(f_scal) - - # Perform the subsequent wavelet transforms for only J-1th scale. - j_iter = J_min - for layer in range(nlayers): - for j in range(j_iter, J + 1): - _, Nj, _ = shapes.LN_j(L, j - layer - 1, N, lam, multiresolution) - wavelet_coefficients = [] - for n in range(2 * Nj - 1): - temp, f_scal = _analysis_scattering( - f_wav[j - J_min][n], - L, - j - layer, - J_min, - lam, - sampling=sampling, - nside=nside, - reality=reality, - multiresolution=multiresolution, - filters=(wav_lm[j - J_min - layer - 1], scal_l), - precomps=precomps[j - J_min - layer - 1], - ) - wavelet_coefficients.append(temp[0]) - scattering_coefficients.append(f_scal) - - f_wav[j - J_min] = jnp.array(wavelet_coefficients) - - j_iter += 1 - - return jnp.array(scattering_coefficients) - - -def _analysis_scattering( - f: jnp.ndarray, - Lin: int, - j: int, - J_min: int = 0, - lam: float = 2.0, - spin: int = 0, - spin0: int = 0, - sampling: str = "mw", - nside: int = None, - reality: bool = False, - multiresolution: bool = False, - filters: Tuple[jnp.ndarray] = None, - precomps: List[List[jnp.ndarray]] = None, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - r"""Wavelet analysis from pixel space to wavelet space for complex signals. - - Args: - f (np.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. - - L (int): Harmonic bandlimit. - - j (int): Wavelet scale. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - spin (int, optional): Spin (integer) of input signal. Defaults to 0. - - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". - - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults - to None. - - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits - conjugate symmetry of harmonic coefficients. Defaults to False. - - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - - spmd (bool, optional): Whether to map compute over multiple devices. Currently this - only maps over all available devices, and is only valid for JAX implementations. - Defaults to False. - - precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most - of length :math:`L^2`, which is a minimal memory overhead. - - Returns: - f_wav (np.ndarray): Array of wavelet pixel-space coefficients - with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. - - f_scal (np.ndarray): Array of scaling pixel-space coefficients - with shape :math:`[n_{\theta}, n_{\phi}]`. - """ - L = shapes.wav_j_bandlimit(Lin, j, lam, multiresolution) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) - flm = s2fft.forward_jax(f, L, spin, nside, sampling, reality) - - # Compute the scaling coefficients as usual. - phi = jnp.einsum( - "lm,l->lm", - flm[:Ls, L - Ls : L - 1 + Ls], - filters[1][:Ls], - optimize=True, - ) - - # Handle edge case - if Ls == 1: - f_scal = phi * jnp.sqrt(1 / (4 * jnp.pi)) - else: - f_scal = s2fft.inverse_jax(phi, Ls, spin, nside, sampling, reality) - - # Get shapes for scale j - 1. - Lj, Nj, L0j = shapes.LN_j(Lin, j - 1, 1, lam, multiresolution) - if j == J_min: - return ( - jnp.zeros((2 * Nj - 1, Lj, 2 * Lj - 1)), - jnp.real(f_scal) if reality else f_scal, - ) - - f_wav_lmn = shapes.construct_flmn_jax( - Lj, Nj, J_min, lam, multiresolution, True - ) - - # Only compute the wavelet coefficients for descending by 1. - f_wav_lmn = f_wav_lmn.at[::2, L0j:].set( - jnp.einsum( - "lm,ln->nlm", - flm[L0j:Lj, L - Lj : L - 1 + Lj], - filters[0][L0j:Lj, L - Nj : L - 1 + Nj : 2], - optimize=True, - ) - ) - f_wav = s2fft.wigner.inverse_jax( - f_wav_lmn, - Lj, - Nj, - nside, - sampling, - reality, - precomps, - spmd=False, - L_lower=L0j, - ) - - return jnp.abs(f_wav), jnp.real(f_scal) if reality else f_scal diff --git a/s2wav/transforms/jax_wavelets.py b/s2wav/transforms/jax_wavelets.py index 5bcc1ae..1cc6e15 100644 --- a/s2wav/transforms/jax_wavelets.py +++ b/s2wav/transforms/jax_wavelets.py @@ -63,7 +63,7 @@ def generate_wigner_precomputes( return precomps -@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) +# @partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -177,7 +177,7 @@ def synthesis( return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality) -@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) +# @partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) def analysis( f: jnp.ndarray, L: int, @@ -193,7 +193,6 @@ def analysis( filters: Tuple[jnp.ndarray] = None, spmd: bool = False, precomps: List[List[jnp.ndarray]] = None, - scattering: bool = False, ) -> Tuple[jnp.ndarray]: r"""Wavelet analysis from pixel space to wavelet space for complex signals. @@ -233,9 +232,6 @@ def analysis( precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most of length :math:`L^2`, which is a minimal memory overhead. - scattering (bool, optional): If using for scattering transform return absolute value - of scattering coefficients. - Returns: f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. @@ -293,8 +289,6 @@ def analysis( spmd_iter, L0j, ) - if scattering: - f_wav[j - J_min] = jnp.abs(f_wav[j - J_min]) # Project all harmonic coefficients for each lm onto scaling coefficients phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) @@ -306,7 +300,7 @@ def analysis( f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi)) else: f_scal = s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality) - return f_wav, jnp.real(f_scal) if reality else f_scal + return f_wav, f_scal @partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 11)) diff --git a/s2wav/transforms/jax_wavelets_precompute.py b/s2wav/transforms/jax_wavelets_precompute.py new file mode 100644 index 0000000..af661ac --- /dev/null +++ b/s2wav/transforms/jax_wavelets_precompute.py @@ -0,0 +1,279 @@ +from jax import jit, config + +config.update("jax_enable_x64", True) + +import jax.numpy as jnp +from s2wav.utils import shapes +from functools import partial +from typing import Tuple, List +from s2fft.precompute_transforms.construct import wigner_kernel, spin_spherical_kernel +from s2fft.precompute_transforms import wigner, spherical + + +@partial(jit, static_argnums=(0, 1, 2, 3, 4, 5, 6, 7, 8)) +def generate_precomputes( + L: int, + N: int, + J_min: int = 0, + lam: float = 2.0, + sampling: str = "mw", + nside: int = None, + forward: bool = False, + reality: bool = False, + multiresolution: bool = False, +) -> List[jnp.ndarray]: + r"""Generates a list of precompute arrays associated with the underlying Wigner + transforms. + + Args: + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + forward (bool, optional): _description_. Defaults to False. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` + resolution or its own resolution. Defaults to False. + + Returns: + List[jnp.ndarray]: Precomputed recursion arrays for underlying Wigner transforms. + """ + precomps = [] + J = shapes.j_max(L, lam) + for j in range(J_min, J + 1): + Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + precomps.append( + wigner_kernel(Lj, Nj, reality, sampling, nside, forward) + ) + Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + precompute_scaling = spin_spherical_kernel(Ls, 0, reality, sampling, nside, forward) + precompute_full = spin_spherical_kernel(L, 0, reality, sampling, nside, not forward) + return precompute_full, precompute_scaling, precomps + + +def synthesis( + f_wav: jnp.ndarray, + f_scal: jnp.ndarray, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + spin0: int = 0, + sampling: str = "mw", + nside: int = None, + reality: bool = False, + multiresolution: bool = False, + filters: Tuple[jnp.ndarray] = None, + spmd: bool = False, + precomps: List[List[jnp.ndarray]] = None, +) -> jnp.ndarray: + r"""Computes the synthesis directional wavelet transform [1,2]. + Specifically, this transform synthesises the signal :math:`_{s}f(\omega) \in \mathbb{S}^2` by summing the contributions from wavelet and scaling coefficients in harmonic space, see equation 27 from `[2] `_. + Args: + f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (jnp.ndarray): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 1. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", + "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if + sampling="healpix". Defaults to None. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` + resolution or its own resolution. Defaults to False. + + filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. + + spmd (bool, optional): Whether to map compute over multiple devices. Currently this + only maps over all available devices, and is only valid for JAX implementations. + Defaults to False. + + precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most + of length :math:`L^2`, which is a minimal memory overhead. + + Raises: + AssertionError: Shape of wavelet/scaling coefficients incorrect. + + Returns: + jnp.ndarray: Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + Notes: + [1] B. Leidstedt et. al., "S2LET: A code to perform fast wavelet analysis on the sphere", A&A, vol. 558, p. A128, 2013. + [2] J. McEwen et. al., "Directional spin wavelets on the sphere", arXiv preprint arXiv:1509.06749 (2015). + """ + if precomps == None: + raise ValueError("Must provide precomputed kernels for this transform!") + + J = shapes.j_max(L, lam) + Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + flm = jnp.zeros((L, 2 * L - 1), dtype=jnp.complex128) + f_scal_lm = spherical.forward_transform_jax(f_scal, precomps[1], Ls, sampling, reality, spin, nside) + + # Sum the all wavelet wigner coefficients for each lmn + # Note that almost the entire compute is concentrated at the highest J + for j in range(J_min, J + 1): + Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + spmd_iter = spmd if N == Nj else False + temp = wigner.forward_transform_jax(f_wav[j - J_min], precomps[2][j-J_min], Lj, Nj, sampling, reality, nside) + flm = flm.at[L0j:Lj, L - Lj : L - 1 + Lj].add( + jnp.einsum( + "ln,nlm->lm", + filters[0][j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + temp[::2, L0j:, :], + optimize=True, + ) + ) + + # Sum the all scaling harmonic coefficients for each lm + phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) + flm = flm.at[:Ls, L - Ls : L - 1 + Ls].add( + jnp.einsum("lm,l->lm", f_scal_lm, phi, optimize=True) + ) + return spherical.inverse_transform_jax(flm, precomps[0], L, sampling, reality, spin, nside) + + +def analysis( + f: jnp.ndarray, + L: int, + N: int = 1, + J_min: int = 0, + lam: float = 2.0, + spin: int = 0, + spin0: int = 0, + sampling: str = "mw", + nside: int = None, + reality: bool = False, + multiresolution: bool = False, + filters: Tuple[jnp.ndarray] = None, + spmd: bool = False, + precomps: List[List[jnp.ndarray]] = None, +) -> Tuple[jnp.ndarray]: + r"""Wavelet analysis from pixel space to wavelet space for complex signals. + + Args: + f (jnp.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. + + L (int): Harmonic bandlimit. + + N (int, optional): Upper azimuthal band-limit. Defaults to 1. + + J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. + + lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. + Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. + + spin (int, optional): Spin (integer) of input signal. Defaults to 0. + + spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. + + sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". + + nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults + to None. + + reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits + conjugate symmetry of harmonic coefficients. Defaults to False. + + multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` + resolution or its own resolution. Defaults to False. + + filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. + + spmd (bool, optional): Whether to map compute over multiple devices. Currently this + only maps over all available devices, and is only valid for JAX implementations. + Defaults to False. + + precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most + of length :math:`L^2`, which is a minimal memory overhead. + + Returns: + f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients + with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. + + f_scal (jnp.ndarray): Array of scaling pixel-space coefficients + with shape :math:`[n_{\theta}, n_{\phi}]`. + """ + if precomps == None: + raise ValueError("Must provide precomputed kernels for this transform!") + + J = shapes.j_max(L, lam) + Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) + + f_wav_lmn = shapes.construct_flmn_jax(L, N, J_min, lam, multiresolution) + f_wav = shapes.construct_f_jax( + L, N, J_min, lam, sampling, nside, multiresolution + ) + + wav_lm = jnp.einsum( + "jln, l->jln", + jnp.conj(filters[0]), + 8 * jnp.pi**2 / (2 * jnp.arange(L) + 1), + optimize=True, + ) + + flm = spherical.forward_transform_jax(f, precomps[0], L, sampling, reality, spin, nside) + # Project all wigner coefficients for each lmn onto wavelet coefficients + # Note that almost the entire compute is concentrated at the highest J + for j in range(J_min, J + 1): + Lj, Nj, L0j = shapes.LN_j(L, j, N, lam, multiresolution) + spmd_iter = spmd if N == Nj else False + f_wav_lmn[j - J_min] = ( + f_wav_lmn[j - J_min] + .at[::2, L0j:] + .add( + jnp.einsum( + "lm,ln->nlm", + flm[L0j:Lj, L - Lj : L - 1 + Lj], + wav_lm[j, L0j:Lj, L - Nj : L - 1 + Nj : 2], + optimize=True, + ) + ) + ) + + f_wav[j-J_min] = wigner.inverse_transform_jax(f_wav_lmn[j - J_min], precomps[2][j - J_min], Lj, Nj, sampling, reality, nside) + + # Project all harmonic coefficients for each lm onto scaling coefficients + phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1)) + temp = jnp.einsum( + "lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi, optimize=True + ) + # Handle edge case + if Ls == 1: + f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi)) + else: + f_scal = spherical.inverse_transform_jax(temp, precomps[1], Ls, sampling, reality, spin, nside) + return f_wav, f_scal diff --git a/s2wav/transforms/numpy_scattering.py b/s2wav/transforms/numpy_scattering.py deleted file mode 100644 index 185391f..0000000 --- a/s2wav/transforms/numpy_scattering.py +++ /dev/null @@ -1,200 +0,0 @@ -import numpy as np -from typing import List, Tuple -from s2wav.transforms import numpy_wavelets as wavelets -from s2wav.filter_factory import filters -from s2wav.utils import shapes -from s2fft import base_transforms as base - - -def scatter( - f: np.ndarray, - L: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - nlayers: int = None, - sampling: str = "mw", - nside: int = None, - reality: bool = False, - multiresolution: bool = False, - filter: Tuple[np.ndarray] = None, -) -> List[np.ndarray]: - """Computes the scattering transform for descending by one nodes alone. - - Following equations outlined in section 3.2 of [1], recursively compute wavelet - transform each time passed through the activation function (in this case the absolute - value, 'modulus' operator) and store the scaling coefficients. - - Args: - f (np.ndarray): _description_ - L (int): _description_ - N (int, optional): _description_. Defaults to 1. - J_min (int, optional): _description_. Defaults to 1. - lam (float, optional): _description_. Defaults to 2.0. - nlayers (int, optional): _description_. Defaults to None. - sampling (str, optional): _description_. Defaults to "mw". - nside (int, optional): _description_. Defaults to None. - reality (bool, optional): _description_. Defaults to False. - multiresolution (bool, optional): _description_. Defaults to False. - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - - Raises: - ValueError: Number of layers is larger than the number of available wavelet scales. - - Returns: - List[np.ndarray]: _description_ - - Notes: - [1] McEwen et al, Scattering networks on the sphere for scalable and - rotationally equivariant spherical CNNs, ICLR 2022. - """ - scattering_coefficients = [] - J = shapes.j_max(L, lam) - - if nlayers is None: - nlayers = J - J_min - if nlayers > J - J_min: - raise ValueError( - f"Number of scattering layers {nlayers} is larger than the number of available wavelet scales {J-J_min}." - ) - - if filter == None: - wav_lm, scal_l = filters.filters_directional_vectorised( - L, N, J_min, lam, 0, 0 - ) - wav_lm = np.einsum( - "jln, l->jln", - np.conj(wav_lm), - 8 * np.pi**2 / (2 * np.arange(L) + 1), - ) - - # Perform the first wavelet transform for all scales - f_wav, f_scal = wavelets.analysis( - f, - L, - N, - J_min, - lam, - sampling=sampling, - nside=nside, - reality=reality, - multiresolution=multiresolution, - scattering=True, - ) - scattering_coefficients.append(f_scal) - - # Perform the subsequent wavelet transforms for only J-1th scale. - j_iter = J_min - for layer in range(nlayers): - for j in range(j_iter, J + 1): - _, Nj, _ = shapes.LN_j(L, j - layer - 1, N, lam, multiresolution) - wavelet_coefficients = [] - for n in range(2 * Nj - 1): - temp, f_scal = wavelets._analysis_scattering( - f_wav[j - J_min][n], - L, - j - layer, - 1, - J_min, - lam, - sampling=sampling, - nside=nside, - reality=reality, - multiresolution=multiresolution, - filters=(wav_lm[j - J_min - layer - 1], scal_l), - ) - wavelet_coefficients.append(temp[0]) - scattering_coefficients.append(f_scal) - f_wav[j - J_min] = np.array(temp) - - j_iter += 1 - - return np.array(scattering_coefficients) - - -def _analysis_scattering( - f: np.ndarray, - Lin: int, - j: int, - N: int = 1, - J_min: int = 0, - lam: float = 2.0, - spin: int = 0, - spin0: int = 0, - sampling: str = "mw", - nside: int = None, - reality: bool = False, - multiresolution: bool = False, - filters: Tuple[np.ndarray] = None, -) -> Tuple[np.ndarray, np.ndarray]: - r"""Wavelet analysis from pixel space to wavelet space for complex signals. - - Args: - f (np.ndarray): Signal :math:`f` on the sphere with shape :math:`[n_{\theta}, n_{\phi}]`. - - L (int): Harmonic bandlimit. - - j (int): Wavelet scale. - - N (int, optional): Upper azimuthal band-limit. Defaults to 1. - - J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0. - - lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales. - Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2. - - spin (int, optional): Spin (integer) of input signal. Defaults to 0. - - spin0 (int, optional): Spin (integer) of output signal. Defaults to 0. - - sampling (str, optional): Spherical sampling scheme from {"mw","mwss", "dh", "healpix"}. Defaults to "mw". - - nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults - to None. - - reality (bool, optional): Whether :math:`f \in \mathbb{R}`, if True exploits - conjugate symmetry of harmonic coefficients. Defaults to False. - - multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}` - resolution or its own resolution. Defaults to False. - - filters (Tuple[jnp.ndarray], optional): Precomputed wavelet filters. Defaults to None. - - Returns: - f_wav (np.ndarray): Array of wavelet pixel-space coefficients - with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`. - - f_scal (np.ndarray): Array of scaling pixel-space coefficients - with shape :math:`[n_{\theta}, n_{\phi}]`. - """ - L = shapes.wav_j_bandlimit(Lin, j, lam, multiresolution) - Ls = shapes.scal_bandlimit(L, J_min, lam, multiresolution) - flm = base.spherical.forward(f, L, spin, sampling, nside, reality) - - # Compute the scaling coefficients as usual. - phi = filters[1][:Ls] * np.sqrt(4 * np.pi / (2 * np.arange(Ls) + 1)) - f_scal = base.spherical.inverse( - np.einsum("lm,l->lm", flm[:Ls, L - Ls : L - 1 + Ls], phi), - Ls, - spin, - sampling, - nside, - reality, - ) - # Get shapes for scale j - 1. - Lj, Nj, L0j = shapes.LN_j(Lin, j - 1, 1, lam, multiresolution) - if j == J_min: - return np.zeros((2 * Nj - 1, Lj, 2 * Lj - 1)), f_scal - - f_wav_lmn = shapes.construct_flmn(Lj, Nj, J_min, lam, multiresolution, True) - - # Only compute the wavelet coefficients for descending by 1. - f_wav_lmn[::2, L0j:] = np.einsum( - "lm,ln->nlm", - flm[L0j:Lj, L - Lj : L - 1 + Lj], - filters[0][L0j:Lj, L - Nj : L - 1 + Nj : 2], - ) - f_wav = base.wigner.inverse( - f_wav_lmn, Lj, Nj, L0j, sampling, reality, nside - ) - return np.abs(f_wav), f_scal diff --git a/tests/test_wavelets_precompute.py b/tests/test_wavelets_precompute.py new file mode 100644 index 0000000..ceb89e2 --- /dev/null +++ b/tests/test_wavelets_precompute.py @@ -0,0 +1,151 @@ +import pytest +import numpy as np +import pys2let as s2let + +from s2wav.transforms import jax_wavelets_precompute as jax_wavelets +from s2wav.filter_factory import filters +from s2wav.utils import shapes +from s2fft import base_transforms as base + +L_to_test = [8] +N_to_test = [2, 3] +J_min_to_test = [2] +lam_to_test = [2, 3] +multiresolution = [False, True] +reality = [False, True] +multiple_gpus = [False] +sampling_to_test = ["mw", "mwss", "dh"] + +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("J_min", J_min_to_test) +@pytest.mark.parametrize("lam", lam_to_test) +@pytest.mark.parametrize("multiresolution", multiresolution) +@pytest.mark.parametrize("reality", reality) +@pytest.mark.parametrize("spmd", multiple_gpus) +def test_jax_synthesis( + wavelet_generator, + L: int, + N: int, + J_min: int, + lam: int, + multiresolution: bool, + reality: bool, + spmd: bool, +): + J = shapes.j_max(L, lam) + if J_min >= J: + pytest.skip("J_min larger than J which isn't a valid test case.") + + f_wav, f_scal, f_wav_s2let, f_scal_s2let = wavelet_generator( + L=L, + N=N, + J_min=J_min, + lam=lam, + multiresolution=multiresolution, + reality=reality, + ) + + f = s2let.synthesis_wav2px( + f_wav_s2let, + f_scal_s2let, + lam, + L, + J_min, + N, + spin=0, + upsample=not multiresolution, + ) + + # Precompute some values + filter = filters.filters_directional_vectorised(L, N, J_min, lam) + precomps = jax_wavelets.generate_precomputes( + L, + N, + J_min, + lam, + forward=True, + reality=reality, + multiresolution=multiresolution, + ) + f_check = jax_wavelets.synthesis( + f_wav, + f_scal, + L, + N, + J_min, + lam, + multiresolution=multiresolution, + reality=reality, + filters=filter, + precomps=precomps, + spmd=spmd, + ) + f = np.real(f) if reality else f + np.testing.assert_allclose(f, f_check.flatten("C"), atol=1e-14) + + +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("J_min", J_min_to_test) +@pytest.mark.parametrize("lam", lam_to_test) +@pytest.mark.parametrize("multiresolution", multiresolution) +@pytest.mark.parametrize("reality", reality) +@pytest.mark.parametrize("spmd", multiple_gpus) +def test_jax_analysis( + flm_generator, + f_wav_converter, + L: int, + N: int, + J_min: int, + lam: int, + multiresolution: bool, + reality: bool, + spmd: bool, +): + J = shapes.j_max(L, lam) + if J_min >= J: + pytest.skip("J_min larger than J which isn't a valid test case.") + + flm = flm_generator(L=L, L_lower=0, spin=0, reality=reality) + f = base.spherical.inverse(flm, L, reality=reality) + + f_wav, f_scal = s2let.analysis_px2wav( + f.flatten("C").astype(np.complex128), + lam, + L, + J_min, + N, + spin=0, + upsample=not multiresolution, + ) + filter = filters.filters_directional_vectorised(L, N, J_min, lam) + precomps = jax_wavelets.generate_precomputes( + L, + N, + J_min, + lam, + forward=False, + reality=reality, + multiresolution=multiresolution, + ) + f_wav_check, f_scal_check = jax_wavelets.analysis( + f, + L, + N, + J_min, + lam, + multiresolution=multiresolution, + reality=reality, + filters=filter, + precomps=precomps, + spmd=spmd, + ) + + f_wav_check = f_wav_converter( + f_wav_check, L, N, J_min, lam, multiresolution + ) + + np.testing.assert_allclose(f_wav, f_wav_check, atol=1e-14) + np.testing.assert_allclose(f_scal, f_scal_check.flatten("C"), atol=1e-14) + From 0d905c81da8e8a0ee0f46c0c22537be16f6a77b9 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Wed, 5 Apr 2023 08:47:20 +0100 Subject: [PATCH 2/2] update jax precompute functionality --- s2wav/transforms/jax_wavelets_precompute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/s2wav/transforms/jax_wavelets_precompute.py b/s2wav/transforms/jax_wavelets_precompute.py index af661ac..7a320cc 100644 --- a/s2wav/transforms/jax_wavelets_precompute.py +++ b/s2wav/transforms/jax_wavelets_precompute.py @@ -64,7 +64,7 @@ def generate_precomputes( precompute_full = spin_spherical_kernel(L, 0, reality, sampling, nside, not forward) return precompute_full, precompute_scaling, precomps - +@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -165,7 +165,7 @@ def synthesis( ) return spherical.inverse_transform_jax(flm, precomps[0], L, sampling, reality, spin, nside) - +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) def analysis( f: jnp.ndarray, L: int,