diff --git a/s2wav/transforms/jax_wavelets_precompute.py b/s2wav/transforms/jax_wavelets_precompute.py index af661ac..7a320cc 100644 --- a/s2wav/transforms/jax_wavelets_precompute.py +++ b/s2wav/transforms/jax_wavelets_precompute.py @@ -64,7 +64,7 @@ def generate_precomputes( precompute_full = spin_spherical_kernel(L, 0, reality, sampling, nside, not forward) return precompute_full, precompute_scaling, precomps - +@partial(jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13)) def synthesis( f_wav: jnp.ndarray, f_scal: jnp.ndarray, @@ -165,7 +165,7 @@ def synthesis( ) return spherical.inverse_transform_jax(flm, precomps[0], L, sampling, reality, spin, nside) - +@partial(jit, static_argnums=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12)) def analysis( f: jnp.ndarray, L: int,