@@ -177,15 +177,15 @@ def jax_prng_key(seed=None):
177
177
@custom_jvp
178
178
def log1p_erf (x ):
179
179
x = jnp .array (x )
180
- return jnp .where (x < - 3 .0 , - x * x - jnp .log (- jnp .sqrt (np .pi )* x ) - 0.5 / (x * x ), jnp .log1p (jss .erf (x )))
180
+ return jnp .where (x < - 4 .0 , - x * x - jnp .log (- jnp .sqrt (np .pi )* x ) + 1 / ( x * x ) * ( - 0.5 + 1 / (x * x ) * ( 5.0 / 8.0 - 37.0 / ( 24.0 * x * x )) ), jnp .log1p (jss .erf (x )))
181
181
182
182
@log1p_erf .defjvp
183
183
def log1p_erf_jvp (primals , tangents ):
184
184
x , = primals
185
185
dx , = tangents
186
186
187
- ans = jnp .where (x < - 3 .0 , - x * x - jnp .log (- jnp .sqrt (np .pi )* x ) - 0.5 / (x * x ), jnp .log1p (jss .erf (x )))
188
- ans_dot = jnp .where (x < - 3 .0 , - 2 * x - 1 / x + 1 / (x * x * x ), 2 / jnp .sqrt (np .pi )* jnp .exp (- x * x )/ (1 + jss .erf (x )))
187
+ ans = jnp .where (x < - 4 .0 , - x * x - jnp .log (- jnp .sqrt (np .pi )* x ) + 1 / ( x * x ) * ( - 0.5 + 1 / (x * x ) * ( 5.0 / 8.0 - 37.0 / ( 24.0 * x * x )) ), jnp .log1p (jss .erf (x )))
188
+ ans_dot = jnp .where (x < - 4 .0 , - 2 * x + 1 / x * ( - 1 + 1 / ( x * x ) * ( 1 + 1 / (x * x ) * ( - 5.0 / 2.0 + 37.0 / ( 4.0 * x * x ))) ), 2 / jnp .sqrt (np .pi )* jnp .exp (- x * x )/ (1 + jss .erf (x )))
189
189
return ans , ans_dot * dx
190
190
191
191
def log_edge_normalization_factor (e , mu_e , sigma_e , e_obs , sigma_e_obs ):
@@ -200,20 +200,34 @@ def log_edge_normalization_factor(e, mu_e, sigma_e, e_obs, sigma_e_obs):
200
200
201
201
return log_numer - log_denom
202
202
203
- def edge_model (Aobs , sigma_obs , e_center_mu = 0.0 , e_center_sigma = 1.0 , c_mu = None , c_sigma = None , c_center = None , mu_bg = None , cov_bg = None , f_bg = None ):
203
+ def mean_sample (postfix , mean_vec , scale_vec ):
204
+ N = mean_vec .shape [0 ]
205
+ mu_unit = numpyro .sample ('mu_unit_' + postfix , dist .Normal (loc = 0 , scale = 1 ), sample_shape = (N ,))
206
+ mu = numpyro .deterministic ('mu_' + postfix , mean_vec + scale_vec * mu_unit )
207
+
208
+ return mu , mu_unit
209
+
210
+ def covariance_sample (postfix , scale_vec , eta = 1 ):
211
+ N = scale_vec .shape [0 ]
212
+
213
+ scale_unit = numpyro .sample ('scale_unit_' + postfix , dist .HalfNormal (scale = 1 ), sample_shape = (N ,))
214
+ scale = numpyro .deterministic ('scale_' + postfix , scale_unit * scale_vec )
215
+ corr_cholesky = numpyro .sample ('corr_cholesky_' + postfix , dist .LKJCholesky (N , eta ))
216
+ cov_cholesky = numpyro .deterministic ('cov_cholesky_' + postfix , scale [:,None ]* corr_cholesky )
217
+ cov = numpyro .deterministic ('cov_' + postfix , jnp .matmul (cov_cholesky , cov_cholesky .T ))
218
+
219
+ return cov , cov_cholesky , corr_cholesky , scale , scale_unit
220
+
221
+ def edge_model (Aobs , cov_obs , e_center_mu = 0.0 , e_center_sigma = 1.0 , c_mu = None , c_sigma = None , c_center = None , mu_bg = None , cov_bg = None , f_bg = None , nu_lkj = 1 ):
204
222
Aobs = np .array (Aobs )
205
- sigma_obs = np .array (sigma_obs )
223
+ cov_obs = np .array (cov_obs )
206
224
207
225
nobs , nband = Aobs .shape
208
- assert sigma_obs .shape == (nobs , nband ), 'size mismatch between `Aobs` and `sigma_obs `'
226
+ assert cov_obs .shape == (nobs , nband , nband ), 'size mismatch between `Aobs` and `cov_obs `'
209
227
210
228
A_mu = np .mean (Aobs , axis = 0 )
211
229
sigma_A = np .std (Aobs , axis = 0 )
212
230
213
- cov_obs = np .zeros ((nobs , nband , nband ))
214
- j ,k = np .diag_indices (nband )
215
- cov_obs [:,j ,k ] = np .square (sigma_obs )
216
-
217
231
if f_bg is None :
218
232
f_bg = numpyro .sample ('f_bg' , dist .Uniform ())
219
233
@@ -235,30 +249,20 @@ def edge_model(Aobs, sigma_obs, e_center_mu=0.0, e_center_sigma=1.0, c_mu=None,
235
249
e_centered = numpyro .deterministic ('e_centered' , e_center_mu + e_center_sigma * e_unit )
236
250
e = numpyro .deterministic ('e' , e_centered + jnp .dot (c , c_center ))
237
251
238
- mu_fg_unit = numpyro .sample ('mu_fg_unit' , dist .Normal (loc = 0 , scale = 1 ), sample_shape = (nband ,))
239
- mu_fg = numpyro .deterministic ('mu_fg' , mu_fg_unit * sigma_A + A_mu )
240
- scale_fg_unit = numpyro .sample ('scale_fg_unit' , dist .HalfNormal (scale = 1 ), sample_shape = (nband ,))
241
- scale_fg = numpyro .deterministic ('scale_fg' , scale_fg_unit * sigma_A )
242
- corr_fg_cholesky = numpyro .sample ('corr_fg_cholesky' , dist .LKJCholesky (nband , 3 ))
243
- cov_fg_cholesky = numpyro .deterministic ('cov_fg_cholesky' , scale_fg [:,None ]* corr_fg_cholesky )
244
- cov_fg = numpyro .deterministic ('cov_fg' , jnp .matmul (cov_fg_cholesky , cov_fg_cholesky .T ))
252
+ mu_fg , _ = mean_sample ('fg' , A_mu , sigma_A )
253
+ cov_fg , _ , _ , _ , _ = covariance_sample ('fg' , sigma_A , nu_lkj )
245
254
246
255
if mu_bg is None and cov_bg is None :
247
- mu_bg_unit = numpyro .sample ('mu_bg_offset' , dist .Normal (loc = 0 , scale = 1 ), sample_shape = (nband ,))
248
- mu_bg = numpyro .deterministic ('mu_bg' , mu_bg_unit * sigma_A + A_mu )
249
- scale_bg_unit = numpyro .sample ('scale_bg_unit' , dist .HalfNormal (scale = 1 ), sample_shape = (nband ,))
250
- scale_bg = numpyro .deterministic ('scale_bg' , scale_bg_unit * sigma_A )
251
- corr_bg_cholesky = numpyro .sample ('corr_bg_cholesky' , dist .LKJCholesky (nband , 3 ))
252
- cov_bg_cholesky = numpyro .deterministic ('cov_bg_cholesky' , scale_bg [:,None ]* corr_bg_cholesky )
253
- cov_bg = numpyro .deterministic ('cov_bg' , jnp .matmul (cov_bg_cholesky , cov_bg_cholesky .T ))
256
+ mu_bg , _ = mean_sample ('bg' , A_mu , sigma_A )
257
+ cov_bg , _ , _ , _ , _ = covariance_sample ('bg' , sigma_A , nu_lkj )
254
258
elif mu_bg is None or cov_bg is None :
255
259
raise ValueError ('either both `mu_bg` and `cov_bg` must be `None` or neither can be `None`' )
256
260
257
261
mu_e = jnp .dot (w , mu_fg )
258
262
e_obs = jnp .dot (Aobs , w )
259
- sigma_e2 = jnp .dot ( w , jnp . dot ( cov_fg , w ) )
263
+ sigma_e2 = jnp .sum ( w [:, None ] * cov_fg * w [ None ,:] )
260
264
sigma_e = jnp .sqrt (sigma_e2 )
261
- sigma_e_obs2 = jnp .sum (w [None ,:]* w [None ,:]* sigma_obs * sigma_obs , axis = 1 )
265
+ sigma_e_obs2 = jnp .sum (w [None ,:, None ]* w [None ,None , :]* cov_obs , axis = ( 1 , 2 ) )
262
266
sigma_e_obs = jnp .sqrt (sigma_e_obs2 )
263
267
264
268
log_alpha = numpyro .deterministic ('log_alpha' , log_edge_normalization_factor (e , mu_e , sigma_e , e_obs , sigma_e_obs ))
0 commit comments