Skip to content

Commit

Permalink
compute distance moments from samples (#193)
Browse files Browse the repository at this point in the history
* compute distance moments from samples

* store distance moments as a part of the results

* add comment

* more comments
  • Loading branch information
deepchatterjeeligo authored Jan 8, 2025
1 parent d412707 commit 5a3bd68
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions amplfi/train/models/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_step(self, batch, _):
descaled.cpu().numpy(),
parameters.cpu().numpy()[0],
)
result.calculate_distance_ansatz()
self.test_results.append(result)

# plot corner and skymap for a subset of the test results
Expand Down
98 changes: 98 additions & 0 deletions amplfi/train/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,69 @@
import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

"""Auxiliary functions for distance ansatz see:10.3847/2041-8205/829/1/L15"""


def P(x):
return np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)


def Q(x):
return sp.special.erfc(x / np.sqrt(2)) / 2


def H(x):
return P(x) / Q(x)


def dHdz(z):
return -H(z) * (H(z) - z)


def x2(z):
return z**2 + 1 + z * H(-z)


def x3(z):
return z**3 + 3 * z + (z**2 + 2) * H(-z)


def x4(z):
return z**4 + 6 * z**2 + 3 + (z**3 + 5 * z) * H(-z)


def x2_prime(z):
return 2 * z + H(-z) + z * dHdz(-z)


def x3_prime(z):
return 3 * z**2 + 3 + 2 * z * H(-z) + (z**2 + 2) * dHdz(-z)


def x4_prime(z):
return (
4 * z**3
+ 12 * z
+ (3 * z**2 + 5) * H(-z)
+ (z**3 + 5 * z) * dHdz(-z)
)


def f(z, s, m):
r = 1 + (s / m) ** 2
r *= x3(z) ** 2
r -= x2(z) * x4(z)
return r


def fprime(z, s, m):
r = 2 * (1 + (s / m) ** 2)
r *= x3(z) * x3_prime(z)
r -= x2(z) * x4_prime(z)
r -= x2_prime(z) * x4(z)
return r


class Result(bilby.result.Result):
Expand Down Expand Up @@ -74,3 +137,38 @@ def plot_mollview(self, nside: int, outpath: Path = None):
plt.savefig(outpath)

return fig

def get_dist_params(self):
"""Calculate d^2, d^3, d^4 moments from posterior samples.
Use them to obtain rho, m, s. This is not conditioned
per pixel."""
d = self.posterior["distance"]
# calculate moments
d_2 = d**2
rho = d_2.sum()
d_3 = d**3
d_3 = d_3.sum()
d_4 = d**4
d_4 = d_4.sum()

m = d_3 / rho
s = np.sqrt(d_4 / rho - m**2)
return rho, m, s

def calculate_distance_ansatz(self, maxiter=10):
"""Calculate the DISTMU, DISTSIGMA, DISTNORM parameters"""
rho, m, s = self.get_dist_params()
z0 = m / s
sol = sp.optimize.root_scalar(f, args=(s, m), fprime=fprime, x0=z0)
if not sol.converged:
self.dist_mu = 0
self.dist_sigma = float("inf")
self.dist_norm = 0
return
z_hat = sol.root
sigma = m * x2(z_hat) / x3(z_hat)
mu = sigma * z_hat
N = 1 / (Q(-z_hat) * sigma**2 * x2(z_hat))
self.dist_mu = mu
self.dist_sigma = sigma
self.norm = N

0 comments on commit 5a3bd68

Please sign in to comment.