From a28c6db88fcfe4792f7bc2cd0711671915172f6a Mon Sep 17 00:00:00 2001 From: brandonyzhao Date: Tue, 30 Apr 2024 13:55:36 -0700 Subject: [PATCH] Fix manual setting of the 0 index for kaiser-squires (jax.ops deprecated) --- jax_lensing/inversion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jax_lensing/inversion.py b/jax_lensing/inversion.py index 72e5eff..869e3ca 100644 --- a/jax_lensing/inversion.py +++ b/jax_lensing/inversion.py @@ -57,8 +57,7 @@ def ks93(g1, g2): p1 = k1 * k1 - k2 * k2 p2 = 2 * k1 * k2 k2 = k1 * k1 + k2 * k2 - #k2[0, 0] = 1 # avoid division by 0 - k2 = jax.ops.index_update(k2, jax.ops.index[0, 0], 1.) # avoid division by 0 + k2 = k2.at[0, 0].set(1.) #avoid division by 0 kEhat = (p1 * g1hat + p2 * g2hat) / k2 kBhat = -(p2 * g1hat - p1 * g2hat) / k2 @@ -107,8 +106,7 @@ def ks93inv(kE, kB): p1 = k1 * k1 - k2 * k2 p2 = 2 * k1 * k2 k2 = k1 * k1 + k2 * k2 - #k2[0, 0] = 1 # avoid division by 0 - k2 = jax.ops.index_update(k2, jax.ops.index[0, 0], 1) # avoid division by 0 + k2 = k2.at[0, 0].set(1.) #avoid division by 0 g1hat = (p1 * kEhat - p2 * kBhat) / k2 g2hat = (p2 * kEhat + p1 * kBhat) / k2