Skip to content

Commit efc234c

Browse files
author
Will M. Farr
committed
I am not insane---there was a problem with the error term.
1 parent 85c2980 commit efc234c

4 files changed

+1101
-246
lines changed

AmIInsane.ipynb

+317
Large diffs are not rendered by default.

GaussianTest.ipynb

+3-1
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,15 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 10,
15+
"execution_count": 19,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
1919
"import arviz as az\n",
2020
"import edge_photometry as ep\n",
21+
"from jax.test_util import check_grads\n",
2122
"import jax.numpy as jnp\n",
23+
"import jax.scipy.special as jss\n",
2224
"import numpy as np\n",
2325
"import numpyro\n",
2426
"import numpyro.distributions as dist\n",

MockPhotometry.ipynb

+751-219
Large diffs are not rendered by default.

edge_photometry.py

+30-26
Original file line number | Diff line number | Diff line change
@@ -177,15 +177,15 @@ def jax_prng_key(seed=None):
177177
def _log1p_erf_value(x):
    """Evaluate ``log(1 + erf(x))`` piecewise, one formula per range of `x`.

    * ``x < -4``: asymptotic series of ``log(erfc(-x))``,
      ``-x^2 - log(-sqrt(pi) x) - 1/(2 x^2) + 5/(8 x^4) - 37/(24 x^6)``.
    * ``-4 <= x < -1``: ``log(erfc(-x))``.  Mathematically identical to
      ``log1p(erf(x))``, but it avoids forming ``1 + erf(x)`` from an
      ``erf(x)`` that has already been rounded to (almost) ``-1``; in single
      precision that rounding makes ``log1p(erf(x))`` inaccurate or ``-inf``
      just above the series switch-over.
    * otherwise: ``log1p(erf(x))``, which is accurate when ``erf(x)`` is not
      close to ``-1`` (including the region near ``x = 0`` where the result
      itself is small).
    """
    x = jnp.asarray(x)
    in_tail = x < -4.0
    in_mid = (x >= -4.0) & (x < -1.0)

    # `jnp.where` evaluates every branch for every element, so hand each
    # branch an argument from its own domain; otherwise the discarded branches
    # would produce NaN/inf (e.g. the log of a negative number for x >= 0).
    x_tail = jnp.where(in_tail, x, -4.0)
    x_mid = jnp.where(in_mid, x, -1.0)
    # A NaN input fails both tests above and therefore lands here, so NaN
    # still propagates to the output.
    x_top = jnp.where(in_tail | in_mid, -1.0, x)

    u = 1/(x_tail*x_tail)
    tail = -x_tail*x_tail - jnp.log(-jnp.sqrt(np.pi)*x_tail) + u*(-0.5 + u*(5.0/8.0 - 37.0/24.0*u))
    mid = jnp.log(jss.erfc(-x_mid))
    top = jnp.log1p(jss.erf(x_top))

    return jnp.where(in_tail, tail, jnp.where(in_mid, mid, top))


def _log1p_erf_slope(x):
    """Evaluate ``d/dx log(1 + erf(x))`` with the same switch-over at -4.

    Below -4 this is the term-by-term derivative of the asymptotic series used
    for the value; elsewhere it is ``2/sqrt(pi) * exp(-x^2) / erfc(-x)``, where
    ``erfc(-x)`` replaces the algebraically equal ``1 + erf(x)`` to avoid the
    cancellation described in `_log1p_erf_value`.
    """
    x = jnp.asarray(x)
    in_tail = x < -4.0

    # Same clamping as for the value: keep each branch inside its own domain
    # so the discarded one cannot generate 0/0 or 1/0.
    x_tail = jnp.where(in_tail, x, -4.0)
    x_body = jnp.where(in_tail, -4.0, x)

    u = 1/(x_tail*x_tail)
    tail = -2*x_tail + 1/x_tail*(-1 + u*(1 + u*(-5.0/2.0 + 37.0/4.0*u)))
    body = 2/jnp.sqrt(np.pi)*jnp.exp(-x_body*x_body)/jss.erfc(-x_body)

    return jnp.where(in_tail, tail, body)


@custom_jvp
def log1p_erf(x):
    """Return ``log(1 + erf(x))`` elementwise, remaining finite for very negative `x`.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which to evaluate the function.

    Returns
    -------
    Array with the shape of `x` holding ``log(1 + erf(x))``.

    The derivative is supplied by `log1p_erf_jvp` rather than obtained by
    differentiating through the piecewise expression.
    """
    x = jnp.array(x)
    return _log1p_erf_value(x)


@log1p_erf.defjvp
def log1p_erf_jvp(primals, tangents):
    """JVP rule for `log1p_erf`: returns ``(value, derivative * tangent)``."""
    x, = primals
    dx, = tangents

    ans = _log1p_erf_value(x)
    ans_dot = _log1p_erf_slope(x)
    return ans, ans_dot*dx
190190

191191
def log_edge_normalization_factor(e, mu_e, sigma_e, e_obs, sigma_e_obs):
@@ -200,20 +200,34 @@ def log_edge_normalization_factor(e, mu_e, sigma_e, e_obs, sigma_e_obs):
200200

201201
return log_numer - log_denom
202202

203-
def edge_model(Aobs, sigma_obs, e_center_mu=0.0, e_center_sigma=1.0, c_mu=None, c_sigma=None, c_center=None, mu_bg=None, cov_bg=None, f_bg=None):
203+
def mean_sample(postfix, mean_vec, scale_vec):
    """Sample a mean vector as ``mean_vec + scale_vec * z`` with ``z ~ Normal(0, 1)``.

    Two numpyro sites are registered, both suffixed with `postfix`:
    ``'mu_unit_' + postfix`` (the unit-normal draws, one per entry along the
    first axis of `mean_vec`) and ``'mu_' + postfix`` (the shifted and scaled
    result, recorded as a deterministic site).

    Returns the tuple ``(mu, mu_unit)``.
    """
    n_dim = mean_vec.shape[0]

    unit_site = 'mu_unit_' + postfix
    offsets = numpyro.sample(
        unit_site,
        dist.Normal(loc=0, scale=1),
        sample_shape=(n_dim,),
    )

    mean_site = 'mu_' + postfix
    shifted = numpyro.deterministic(mean_site, mean_vec + scale_vec*offsets)

    return shifted, offsets
209+
210+
def covariance_sample(postfix, scale_vec, eta=1):
    """Sample a covariance matrix from half-normal scales and an LKJ Cholesky factor.

    Sites registered with numpyro, each suffixed with `postfix`, in this order:

    * ``scale_unit_``    -- ``HalfNormal(1)`` draws, one per entry of `scale_vec`
    * ``scale_``         -- those draws multiplied by `scale_vec`
    * ``corr_cholesky_`` -- a draw from ``LKJCholesky(N, eta)`` with
      ``N = scale_vec.shape[0]``
    * ``cov_cholesky_``  -- the correlation factor with row ``i`` multiplied by
      ``scale[i]``
    * ``cov_``           -- ``cov_cholesky @ cov_cholesky.T``

    Returns ``(cov, cov_cholesky, corr_cholesky, scale, scale_unit)``.
    """
    n_dim = scale_vec.shape[0]

    unit_scales = numpyro.sample(
        'scale_unit_' + postfix,
        dist.HalfNormal(scale=1),
        sample_shape=(n_dim,),
    )
    scales = numpyro.deterministic('scale_' + postfix, unit_scales * scale_vec)

    corr_factor = numpyro.sample(
        'corr_cholesky_' + postfix,
        dist.LKJCholesky(n_dim, eta),
    )

    # Scale each row of the correlation factor by the matching entry of `scales`.
    row_scaled = scales[:, None]*corr_factor
    cov_factor = numpyro.deterministic('cov_cholesky_' + postfix, row_scaled)

    outer = jnp.matmul(cov_factor, cov_factor.T)
    covariance = numpyro.deterministic('cov_' + postfix, outer)

    return covariance, cov_factor, corr_factor, scales, unit_scales
220+
221+
def edge_model(Aobs, cov_obs, e_center_mu=0.0, e_center_sigma=1.0, c_mu=None, c_sigma=None, c_center=None, mu_bg=None, cov_bg=None, f_bg=None, nu_lkj=1):
204222
Aobs = np.array(Aobs)
205-
sigma_obs = np.array(sigma_obs)
223+
cov_obs = np.array(cov_obs)
206224

207225
nobs, nband = Aobs.shape
208-
assert sigma_obs.shape == (nobs, nband), 'size mismatch between `Aobs` and `sigma_obs`'
226+
assert cov_obs.shape == (nobs, nband, nband), 'size mismatch between `Aobs` and `cov_obs`'
209227

210228
A_mu = np.mean(Aobs, axis=0)
211229
sigma_A = np.std(Aobs, axis=0)
212230

213-
cov_obs = np.zeros((nobs, nband, nband))
214-
j,k = np.diag_indices(nband)
215-
cov_obs[:,j,k] = np.square(sigma_obs)
216-
217231
if f_bg is None:
218232
f_bg = numpyro.sample('f_bg', dist.Uniform())
219233

@@ -235,30 +249,20 @@ def edge_model(Aobs, sigma_obs, e_center_mu=0.0, e_center_sigma=1.0, c_mu=None,
235249
e_centered = numpyro.deterministic('e_centered', e_center_mu + e_center_sigma*e_unit)
236250
e = numpyro.deterministic('e', e_centered + jnp.dot(c, c_center))
237251

238-
mu_fg_unit = numpyro.sample('mu_fg_unit', dist.Normal(loc=0, scale=1), sample_shape=(nband,))
239-
mu_fg = numpyro.deterministic('mu_fg', mu_fg_unit*sigma_A + A_mu)
240-
scale_fg_unit = numpyro.sample('scale_fg_unit', dist.HalfNormal(scale=1), sample_shape=(nband,))
241-
scale_fg = numpyro.deterministic('scale_fg', scale_fg_unit*sigma_A)
242-
corr_fg_cholesky = numpyro.sample('corr_fg_cholesky', dist.LKJCholesky(nband, 3))
243-
cov_fg_cholesky = numpyro.deterministic('cov_fg_cholesky', scale_fg[:,None]*corr_fg_cholesky)
244-
cov_fg = numpyro.deterministic('cov_fg', jnp.matmul(cov_fg_cholesky, cov_fg_cholesky.T))
252+
mu_fg, _ = mean_sample('fg', A_mu, sigma_A)
253+
cov_fg, _, _, _, _ = covariance_sample('fg', sigma_A, nu_lkj)
245254

246255
if mu_bg is None and cov_bg is None:
247-
mu_bg_unit = numpyro.sample('mu_bg_offset', dist.Normal(loc=0, scale=1), sample_shape=(nband,))
248-
mu_bg = numpyro.deterministic('mu_bg', mu_bg_unit*sigma_A + A_mu)
249-
scale_bg_unit = numpyro.sample('scale_bg_unit', dist.HalfNormal(scale=1), sample_shape=(nband,))
250-
scale_bg = numpyro.deterministic('scale_bg', scale_bg_unit*sigma_A)
251-
corr_bg_cholesky = numpyro.sample('corr_bg_cholesky', dist.LKJCholesky(nband, 3))
252-
cov_bg_cholesky = numpyro.deterministic('cov_bg_cholesky', scale_bg[:,None]*corr_bg_cholesky)
253-
cov_bg = numpyro.deterministic('cov_bg', jnp.matmul(cov_bg_cholesky, cov_bg_cholesky.T))
256+
mu_bg, _ = mean_sample('bg', A_mu, sigma_A)
257+
cov_bg, _, _, _, _ = covariance_sample('bg', sigma_A, nu_lkj)
254258
elif mu_bg is None or cov_bg is None:
255259
raise ValueError('either both `mu_bg` and `cov_bg` must be `None` or neither can be `None`')
256260

257261
mu_e = jnp.dot(w, mu_fg)
258262
e_obs = jnp.dot(Aobs, w)
259-
sigma_e2 = jnp.dot(w, jnp.dot(cov_fg, w))
263+
sigma_e2 = jnp.sum(w[:,None]*cov_fg*w[None,:])
260264
sigma_e = jnp.sqrt(sigma_e2)
261-
sigma_e_obs2 = jnp.sum(w[None,:]*w[None,:]*sigma_obs*sigma_obs, axis=1)
265+
sigma_e_obs2 = jnp.sum(w[None,:,None]*w[None,None,:]*cov_obs, axis=(1, 2))
262266
sigma_e_obs = jnp.sqrt(sigma_e_obs2)
263267

264268
log_alpha = numpyro.deterministic('log_alpha', log_edge_normalization_factor(e, mu_e, sigma_e, e_obs, sigma_e_obs))

0 commit comments

Comments
 (0)