Skip to content

Commit

Permalink
Add utils.simulate, give parameter prior_variance to TimingModel, rel…
Browse files Browse the repository at this point in the history
…ocate utils.KernelMatrix
  • Loading branch information
vallis committed May 17, 2023
1 parent 3ad5205 commit ab7ef2b
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 90 deletions.
6 changes: 3 additions & 3 deletions enterprise/signals/gp_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from sksparse.cholmod import cholesky

from enterprise.signals import parameter, selections, signal_base, utils
from enterprise.signals.signal_base import KernelMatrix
from enterprise.signals.parameter import function
from enterprise.signals.selections import Selection
from enterprise.signals.utils import KernelMatrix

# logging.basicConfig(format="%(levelname)s: %(name)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -225,11 +225,11 @@ def get_timing_model_basis(use_svd=False, normed=True):
return utils.unnormed_tm_basis()


def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True):
def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True, prior_variance=1e40):
"""Class factory for marginalized linear timing model signals."""

basis = get_timing_model_basis(use_svd, normed)
prior = utils.tm_prior()
prior = utils.tm_prior(variance=prior_variance)

BaseClass = BasisGP(prior, basis, coefficients=coefficients, name=name)

Expand Down
84 changes: 83 additions & 1 deletion enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from enterprise.signals.parameter import Function # noqa: F401
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix

from enterprise import __version__
from sys import version
Expand Down Expand Up @@ -1212,6 +1211,89 @@ def solve(self, other, left_array=None, logdet=False):
return (ret, self._get_logdet()) if logdet else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

Check warning on line 1238 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1238

Added line #L1238 was not covered by tests

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other

Check warning on line 1248 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1248

Added line #L1248 was not covered by tests
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other

Check warning on line 1264 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1264

Added line #L1264 was not covered by tests
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))

Check warning on line 1290 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1286-L1290

Added lines #L1286 - L1290 were not covered by tests
if logdet:
return inv, ld
else:
return inv


class ShermanMorrison(object):
"""Custom container class for Sherman-morrison array inversion."""

Expand Down
155 changes: 70 additions & 85 deletions enterprise/signals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enterprise
from enterprise import constants as const
from enterprise import signals as sigs # noqa: F401
from enterprise.signals.signal_base import ShermanMorrison
from enterprise.signals.gp_bases import ( # noqa: F401
createfourierdesignmatrix_dm,
createfourierdesignmatrix_env,
Expand All @@ -31,6 +32,73 @@
logger = logging.getLogger(__name__)


def simulate(pta, params, sparse_cholesky=True):
"""Simulate residuals for all pulsars in `pta` by sampling all white-noise
and GP objects for parameters `params`. Requires GPs to have `combine=False`,
and will run faster with GP ECORR. If `pta` includes a `TimingModel`, that
should be created with a small `prior_variance`. This function can be used
with `utils.set_residuals` to replace residuals in a `Pulsar` object.
Note that any PTA built from that `Pulsar` may nevertheless cache residuals
internally, so it is safer to rebuild the PTA with the modified `Pulsar`."""

delays, ndiags, fmats, phis = (
pta.get_delay(params=params),
pta.get_ndiag(params=params),
pta.get_basis(params=params),
pta.get_phi(params=params),
)

gpresiduals = []
if pta._commonsignals:
if sparse_cholesky:
cf = cholesky(sps.csc_matrix(phis))
gp = np.zeros(phis.shape[0])
gp[cf.P()] = np.dot(cf.L().toarray(), np.random.randn(phis.shape[0]))
else:
gp = np.dot(sl.cholesky(phis, lower=True), np.random.randn(phis.shape[0]))

Check warning on line 58 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L58

Added line #L58 was not covered by tests

i = 0
for fmat in fmats:
j = i + fmat.shape[1]
gpresiduals.append(np.dot(fmat, gp[i:j]))
i = j

assert len(gp) == i
else:
for fmat, phi in zip(fmats, phis):
if phi is None:
gpresiduals.append(0)
elif phi.ndim == 1:
gpresiduals.append(np.dot(fmat, np.sqrt(phi) * np.random.randn(phi.shape[0])))

Check warning on line 72 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L68-L72

Added lines #L68 - L72 were not covered by tests
else:
raise NotImplementedError

Check warning on line 74 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L74

Added line #L74 was not covered by tests

whiteresiduals = []
for delay, ndiag in zip(delays, ndiags):
if ndiag is None:
whiteresiduals.append(0)

Check warning on line 79 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L79

Added line #L79 was not covered by tests
elif isinstance(ndiag, ShermanMorrison):
# this code is very slow...
n = np.diag(ndiag._nvec)
for j, s in zip(ndiag._jvec, ndiag._slices):
n[s, s] += j
whiteresiduals.append(delay + np.dot(sl.cholesky(n, lower=True), np.random.randn(n.shape[0])))

Check warning on line 85 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L82-L85

Added lines #L82 - L85 were not covered by tests
elif ndiag.ndim == 1:
whiteresiduals.append(delay + np.sqrt(ndiag) * np.random.randn(ndiag.shape[0]))
else:
raise NotImplementedError

Check warning on line 89 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L89

Added line #L89 was not covered by tests

return [np.array(g + w) for g, w in zip(gpresiduals, whiteresiduals)]


def set_residuals(psr, y):
if isinstance(psr, list):
for p, r in zip(psr, y):
p._residuals[p._isort] = r

Check warning on line 97 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L95-L97

Added lines #L95 - L97 were not covered by tests
else:
psr._residuals[psr._isort] = y

Check warning on line 99 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L99

Added line #L99 was not covered by tests


class ConditionalGP:
def __init__(self, pta, phiinv_method="cliques"):
"""This class allows the computation of conditional means and
Expand Down Expand Up @@ -208,89 +276,6 @@ def get_coefficients(pta, params, n=1, phiinv_method="cliques", variance=True, c
return ret[0] if n == 1 else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))
if logdet:
return inv, ld
else:
return inv


def create_stabletimingdesignmatrix(designmat, fastDesign=True):
"""
Stabilize the timing-model design matrix.
Expand Down Expand Up @@ -885,8 +870,8 @@ def svd_tm_basis(Mmat):


@function
def tm_prior(weights):
return weights * 1e40
def tm_prior(weights, variance=1e40):
return weights * variance


# Physical ephemeris model utility functions
Expand Down
25 changes: 24 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import enterprise.constants as const
from enterprise.pulsar import Pulsar
from enterprise.signals import utils, parameter, signal_base, white_signals, gp_signals
from enterprise.signals import anis_coefficients as anis
from enterprise.signals import utils
from tests.enterprise_test_data import datadir


Expand All @@ -26,6 +26,7 @@ def setUpClass(cls):

# initialize Pulsar class
cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")
cls.psr2 = Pulsar(datadir + "/J1909-3744_NANOGrav_9yv1.gls.par", datadir + "/J1909-3744_NANOGrav_9yv1.tim")

cls.F, _ = utils.createfourierdesignmatrix_red(cls.psr.toas, nmodes=30)

Expand All @@ -35,6 +36,28 @@ def setUpClass(cls):

cls.Mm = utils.create_stabletimingdesignmatrix(cls.psr.Mmat)

def test_simulate(self):
ef = white_signals.MeasurementNoise()

ec = gp_signals.EcorrBasisModel()

pl = utils.powerlaw(log10_A=parameter.Uniform(-16, -13), gamma=parameter.Uniform(1, 7))
orf = utils.hd_orf()
crn = gp_signals.FourierBasisCommonGP(pl, orf, components=20, name="GW")

m = ef + ec + crn

pta = signal_base.PTA([m(self.psr), m(self.psr2)])

ys = utils.simulate(pta, params=parameter.sample(pta.params))

msg = "Simulated residuals shape incorrect"
assert ys[0].shape == self.psr.residuals.shape, msg
assert ys[1].shape == self.psr2.residuals.shape, msg

msg = "Simulated residuals shape not a number"
assert np.all(~np.isnan(np.concatenate(ys))), msg

def test_createstabletimingdesignmatrix(self):
"""Timing model design matrix shape."""

Expand Down

0 comments on commit ab7ef2b

Please sign in to comment.