Skip to content

Commit

Permalink
Merge pull request #82 from astro-informatics/feature/partial_analysi…
Browse files Browse the repository at this point in the history
…s_switch

bug fix: remove redundant memory allocation for partial analysis
  • Loading branch information
CosmoMatt authored Apr 13, 2024
2 parents 06f77b8 + b50a3e8 commit 5ebccdb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 128 deletions.
162 changes: 48 additions & 114 deletions s2wav/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import math
from functools import partial
from typing import Tuple
from typing import Tuple, List
from s2fft.sampling import s2_samples, so3_samples
from scipy.special import loggamma
from jax.scipy.special import gammaln as jax_gammaln
Expand Down Expand Up @@ -232,116 +232,62 @@ def construct_f(
return f


@partial(jit, static_argnums=(0, 1, 2, 3, 4, 5, 6, 7))
@partial(jit, static_argnums=(0, 1, 2, 3))
def construct_f_jax(
L: int,
N: int = 1,
J_min: int = 0,
lam: float = 2.0,
sampling: str = "mw",
nside: int = None,
multiresolution: bool = False,
scattering: bool = False,
) -> jnp.ndarray:
"""Defines a list of arrays corresponding to f_wav.
J_max: int = None,
lam: float = 2.0
) -> List:
"""Defines a list corresponding to f_wav.
Args:
L (int): Harmonic bandlimit.
N (int, optional): Upper orientational band-limit. Defaults to 1.
J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0.
J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None.
lam (float, optional): Wavelet parameter which determines the scale factor between
consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic
wavelets. Defaults to 2.
sampling (str, optional): Spherical sampling scheme from
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
nside (int, optional): HEALPix Nside resolution parameter. Only required if
sampling="healpix". Defaults to None.
multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}`
resolution or its own resolution. Defaults to False.
scattering (bool, optional): Whether to create minimal arrays for scattering transform to
optimise for memory. Defaults to False.
Returns:
jnp.ndarray: Empty array (or list of empty arrays) in which to write data.
List: Empty list in which to write data.
"""
J = j_max(L, lam)
if scattering:
f = jnp.zeros(
f_wav_j(L, J - 1, N, lam, sampling, nside, multiresolution),
dtype=jnp.complex128,
)
else:
f = []
for j in range(J_min, J + 1):
f.append(
jnp.zeros(
f_wav_j(L, j, N, lam, sampling, nside, multiresolution),
dtype=jnp.complex128,
)
)
J = J_max if J_max is not None else j_max(L, lam)
f = []
for _ in range(J_min, J + 1):
f.append([])
return f


def construct_f_torch(
L: int,
N: int = 1,
J_min: int = 0,
lam: float = 2.0,
sampling: str = "mw",
nside: int = None,
multiresolution: bool = False,
scattering: bool = False,
) -> torch.tensor:
"""Defines a list of tensors corresponding to f_wav.
J_max: int = None,
lam: float = 2.0
) -> List:
"""Defines a list corresponding to f_wav.
Args:
L (int): Harmonic bandlimit.
N (int, optional): Upper orientational band-limit. Defaults to 1.
J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0.
J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None.
lam (float, optional): Wavelet parameter which determines the scale factor between
consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic
wavelets. Defaults to 2.
sampling (str, optional): Spherical sampling scheme from
{"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw".
nside (int, optional): HEALPix Nside resolution parameter. Only required if
sampling="healpix". Defaults to None.
multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}`
resolution or its own resolution. Defaults to False.
scattering (bool, optional): Whether to create minimal arrays for scattering transform to
optimise for memory. Defaults to False.
Returns:
torch.tensor: Empty tensor (or list of empty tensors) in which to write data.
List: Empty list in which to write data.
"""
J = j_max(L, lam)
if scattering:
f = torch.zeros(
f_wav_j(L, J - 1, N, lam, sampling, nside, multiresolution),
dtype=torch.complex128,
)
else:
f = []
for j in range(J_min, J + 1):
f.append(
torch.zeros(
f_wav_j(L, j, N, lam, sampling, nside, multiresolution),
dtype=torch.complex128,
)
)
J = J_max if J_max is not None else j_max(L, lam)
f = []
for _ in range(J_min, J + 1):
f.append([])
return f


Expand Down Expand Up @@ -537,9 +483,9 @@ def construct_flmn_jax(
L: int,
N: int = 1,
J_min: int = 0,
J_max: int = None,
lam: float = 2.0,
multiresolution: bool = False,
scattering: bool = False,
multiresolution: bool = False
) -> jnp.ndarray:
"""Defines a list of arrays corresponding to flmn.
Expand All @@ -550,43 +496,37 @@ def construct_flmn_jax(
J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0.
J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None.c
lam (float, optional): Wavelet parameter which determines the scale factor between
consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic
wavelets. Defaults to 2.
multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}`
resolution or its own resolution. Defaults to False.
scattering (bool, optional): Whether to create minimal arrays for scattering transform to
optimise for memory. Defaults to False.
Returns:
jnp.ndarray: Empty array (or list of empty arrays) in which to write data.
"""
J = j_max(L, lam)
if scattering:
flmn = jnp.zeros(
flmn_wav_j(L, J - 1, N, lam, multiresolution), dtype=jnp.complex128
)
else:
flmn = []
for j in range(J_min, J + 1):
flmn.append(
jnp.zeros(
flmn_wav_j(L, j, N, lam, multiresolution),
dtype=jnp.complex128,
)
J = J_max if J_max is not None else j_max(L, lam)
flmn = []
for j in range(J_min, J + 1):
flmn.append(
jnp.zeros(
flmn_wav_j(L, j, N, lam, multiresolution),
dtype=jnp.complex128,
)
)
return flmn


def construct_flmn_torch(
L: int,
N: int = 1,
J_min: int = 0,
J_max: int = None,
lam: float = 2.0,
multiresolution: bool = False,
scattering: bool = False,
multiresolution: bool = False
) -> torch.tensor:
"""Defines a list of tensors corresponding to flmn.
Expand All @@ -597,33 +537,27 @@ def construct_flmn_torch(
J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0.
J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None.c
lam (float, optional): Wavelet parameter which determines the scale factor between
consecutive wavelet scales. Note that :math:`\lambda = 2` indicates dyadic
wavelets. Defaults to 2.
multiresolution (bool, optional): Whether to store the scales at :math:`j_{\text{max}}`
resolution or its own resolution. Defaults to False.
scattering (bool, optional): Whether to create minimal arrays for scattering transform to
optimise for memory. Defaults to False.
Returns:
torch.tensor: Empty tensor (or list of empty tensors) in which to write data.
"""
J = j_max(L, lam)
if scattering:
flmn = torch.zeros(
flmn_wav_j(L, J - 1, N, lam, multiresolution), dtype=torch.complex128
)
else:
flmn = []
for j in range(J_min, J + 1):
flmn.append(
torch.zeros(
flmn_wav_j(L, j, N, lam, multiresolution),
dtype=torch.complex128,
)
J = J_max if J_max is not None else j_max(L, lam)
flmn = []
for j in range(J_min, J + 1):
flmn.append(
torch.zeros(
flmn_wav_j(L, j, N, lam, multiresolution),
dtype=torch.complex128,
)
)
return flmn


Expand Down
14 changes: 8 additions & 6 deletions s2wav/transforms/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def analysis(
J = samples.j_max(L, lam)
Ls = samples.scal_bandlimit(L, J_min, lam, True)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True)
f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
Expand Down Expand Up @@ -345,6 +345,8 @@ def flm_to_analysis(
J_min (int, optional): Lowest frequency wavelet scale to be used. Defaults to 0.
J_max (int, optional): Highest frequency wavelet scale to be used. Defaults to None.
lam (float, optional): Wavelet parameter which determines the scale factor between consecutive wavelet scales.
Note that :math:`\lambda = 2` indicates dyadic wavelets. Defaults to 2.
Expand Down Expand Up @@ -379,8 +381,8 @@ def flm_to_analysis(

J = J_max if J_max is not None else samples.j_max(L, lam)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True)
f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
Expand All @@ -406,9 +408,9 @@ def flm_to_analysis(
)
)

f_wav[j - J_min] = (
f_wav[j - J_min] = jnp.array(
s2fft.wigner.inverse(
f_wav_lmn[j - J_min],
jnp.array(f_wav_lmn[j - J_min]),
Lj,
Nj,
nside,
Expand Down
8 changes: 4 additions & 4 deletions s2wav/transforms/wavelet_precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def analysis(
J = samples.j_max(L, lam)
Ls = samples.scal_bandlimit(L, J_min, lam, True)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True)
f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
Expand Down Expand Up @@ -279,8 +279,8 @@ def flm_to_analysis(

J = J_max if J_max is not None else samples.j_max(L, lam)

f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, lam, True)
f_wav = samples.construct_f_jax(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_jax(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_jax(L, J_min, J, lam)

wav_lm = jnp.einsum(
"jln, l->jln",
Expand Down
8 changes: 4 additions & 4 deletions s2wav/transforms/wavelet_precompute_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def analysis(
J = samples.j_max(L, lam)
Ls = samples.scal_bandlimit(L, J_min, lam, True)

f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, lam, True)
f_wav = samples.construct_f_torch(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_torch(L, J_min, J, lam)

wav_lm = torch.einsum(
"jln, l->jln",
Expand Down Expand Up @@ -265,8 +265,8 @@ def flm_to_analysis(

J = J_max if J_max is not None else samples.j_max(L, lam)

f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, lam, True)
f_wav = samples.construct_f_torch(L, N, J_min, lam, sampling, nside, True)
f_wav_lmn = samples.construct_flmn_torch(L, N, J_min, J, lam, True)
f_wav = samples.construct_f_torch(L, J_min, J, lam)

wav_lm = torch.einsum(
"jln, l->jln",
Expand Down

0 comments on commit 5ebccdb

Please sign in to comment.