Skip to content

Commit

Permalink
Merge pull request #76 from astro-informatics/feature/precompute_tran…
Browse files Browse the repository at this point in the history
…sform

Feature/precompute transform
  • Loading branch information
CosmoMatt authored Sep 27, 2023
2 parents d82bec0 + 0d905c8 commit e0a1f8a
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 485 deletions.
3 changes: 1 addition & 2 deletions s2wav/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from . import numpy_wavelets
from . import numpy_scattering
from . import jax_wavelets
from . import jax_scattering
from . import jax_wavelets_precompute
274 changes: 0 additions & 274 deletions s2wav/transforms/jax_scattering.py

This file was deleted.

12 changes: 3 additions & 9 deletions s2wav/transforms/jax_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def generate_wigner_precomputes(
return precomps


@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13))
# @partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13))
def synthesis(
f_wav: jnp.ndarray,
f_scal: jnp.ndarray,
Expand Down Expand Up @@ -177,7 +177,7 @@ def synthesis(
return s2fft.inverse_jax(flm, L, spin, nside, sampling, reality)


@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12))
# @partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12))
def analysis(
f: jnp.ndarray,
L: int,
Expand All @@ -193,7 +193,6 @@ def analysis(
filters: Tuple[jnp.ndarray] = None,
spmd: bool = False,
precomps: List[List[jnp.ndarray]] = None,
scattering: bool = False,
) -> Tuple[jnp.ndarray]:
r"""Wavelet analysis from pixel space to wavelet space for complex signals.
Expand Down Expand Up @@ -233,9 +232,6 @@ def analysis(
precomps (List[jnp.ndarray]): Precomputed list of recursion coefficients. At most
of length :math:`L^2`, which is a minimal memory overhead.
scattering (bool, optional): If using for scattering transform return absolute value
of scattering coefficients.
Returns:
f_wav (jnp.ndarray): Array of wavelet pixel-space coefficients
with shape :math:`[n_{J}, 2N-1, n_{\theta}, n_{\phi}]`.
Expand Down Expand Up @@ -293,8 +289,6 @@ def analysis(
spmd_iter,
L0j,
)
if scattering:
f_wav[j - J_min] = jnp.abs(f_wav[j - J_min])

# Project all harmonic coefficients for each lm onto scaling coefficients
phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1))
Expand All @@ -306,7 +300,7 @@ def analysis(
f_scal = temp * jnp.sqrt(1 / (4 * jnp.pi))
else:
f_scal = s2fft.inverse_jax(temp, Ls, spin, nside, sampling, reality)
return f_wav, jnp.real(f_scal) if reality else f_scal
return f_wav, f_scal


@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 11))
Expand Down
Loading

0 comments on commit e0a1f8a

Please sign in to comment.