Skip to content

Commit

Permalink
update jax precompute functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
CosmoMatt committed Apr 5, 2023
1 parent 8db5716 commit 0d905c8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions s2wav/transforms/jax_wavelets_precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def generate_precomputes(
precompute_full = spin_spherical_kernel(L, 0, reality, sampling, nside, not forward)
return precompute_full, precompute_scaling, precomps


@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13))
def synthesis(
f_wav: jnp.ndarray,
f_scal: jnp.ndarray,
Expand Down Expand Up @@ -165,7 +165,7 @@ def synthesis(
)
return spherical.inverse_transform_jax(flm, precomps[0], L, sampling, reality, spin, nside)


@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12))
def analysis(
f: jnp.ndarray,
L: int,
Expand Down

0 comments on commit 0d905c8

Please sign in to comment.