Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce utils.simulate() #348

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions enterprise/signals/gp_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from sksparse.cholmod import cholesky

from enterprise.signals import parameter, selections, signal_base, utils
from enterprise.signals.signal_base import KernelMatrix
from enterprise.signals.parameter import function
from enterprise.signals.selections import Selection
from enterprise.signals.utils import KernelMatrix

# logging.basicConfig(format="%(levelname)s: %(name)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -225,11 +225,11 @@ def get_timing_model_basis(use_svd=False, normed=True):
return utils.unnormed_tm_basis()


def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True):
def TimingModel(coefficients=False, name="linear_timing_model", use_svd=False, normed=True, prior_variance=1e40):
"""Class factory for marginalized linear timing model signals."""

basis = get_timing_model_basis(use_svd, normed)
prior = utils.tm_prior()
prior = utils.tm_prior(variance=prior_variance)

BaseClass = BasisGP(prior, basis, coefficients=coefficients, name=name)

Expand Down Expand Up @@ -848,6 +848,7 @@ def MNMMNF(self, T):

# we're ignoring logdet = True for two-dimensional cases, but OK
def solve(self, right, left_array=None, logdet=False):
# compute generalized version of r+ N^-1 r
if right.ndim == 1 and left_array is right:
res = right

Expand All @@ -856,11 +857,13 @@ def solve(self, right, left_array=None, logdet=False):
MNr = self.MNr(res)
ret = rNr - np.dot(MNr, self.cf(MNr))
return (ret, logdet_N + self.cf.logdet() + self.Mprior) if logdet else ret
# compute generalized version of T+ N^-1 r
elif right.ndim == 1 and left_array is not None and left_array.ndim == 2:
res, T = right, left_array

TNr = self.Nmat.solve(res, left_array=T)
return TNr - np.tensordot(self.MNMMNF(T), self.MNr(res), (0, 0))
# compute generalized version of T+ N^-1 T
elif right.ndim == 2 and left_array is right:
T = right

Expand Down
84 changes: 83 additions & 1 deletion enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from enterprise.signals.parameter import Function # noqa: F401
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix

from enterprise import __version__
from sys import version
Expand Down Expand Up @@ -1212,6 +1211,89 @@
return (ret, self._get_logdet()) if logdet else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

Check warning on line 1238 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1238

Added line #L1238 was not covered by tests

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other

Check warning on line 1248 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1248

Added line #L1248 was not covered by tests
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other

Check warning on line 1264 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1264

Added line #L1264 was not covered by tests
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a # pragma: no cover here?

u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))

Check warning on line 1290 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1286-L1290

Added lines #L1286 - L1290 were not covered by tests
if logdet:
return inv, ld
else:
return inv


class ShermanMorrison(object):
"""Custom container class for Sherman-morrison array inversion."""

Expand Down
155 changes: 70 additions & 85 deletions enterprise/signals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import enterprise
from enterprise import constants as const
from enterprise import signals as sigs # noqa: F401
from enterprise.signals.signal_base import ShermanMorrison
from enterprise.signals.gp_bases import ( # noqa: F401
createfourierdesignmatrix_dm,
createfourierdesignmatrix_env,
Expand All @@ -31,6 +32,73 @@
logger = logging.getLogger(__name__)


def simulate(pta, params, sparse_cholesky=True):
"""Simulate residuals for all pulsars in `pta` by sampling all white-noise
and GP objects for parameters `params`. Requires GPs to have `combine=False`,
and will run faster with GP ECORR. If `pta` includes a `TimingModel`, that
should be created with a small `prior_variance`. This function can be used
with `utils.set_residuals` to replace residuals in a `Pulsar` object.
Note that any PTA built from that `Pulsar` may nevertheless cache residuals
internally, so it is safer to rebuild the PTA with the modified `Pulsar`."""

delays, ndiags, fmats, phis = (
pta.get_delay(params=params),
pta.get_ndiag(params=params),
pta.get_basis(params=params),
pta.get_phi(params=params),
)

gpresiduals = []
if pta._commonsignals:
if sparse_cholesky:
cf = cholesky(sps.csc_matrix(phis))
gp = np.zeros(phis.shape[0])
gp[cf.P()] = np.dot(cf.L().toarray(), np.random.randn(phis.shape[0]))
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a #pragma: no cover

gp = np.dot(sl.cholesky(phis, lower=True), np.random.randn(phis.shape[0]))

Check warning on line 58 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L58

Added line #L58 was not covered by tests

i = 0
for fmat in fmats:
j = i + fmat.shape[1]
gpresiduals.append(np.dot(fmat, gp[i:j]))
i = j

assert len(gp) == i
else:
for fmat, phi in zip(fmats, phis):
if phi is None:
gpresiduals.append(0)
elif phi.ndim == 1:
gpresiduals.append(np.dot(fmat, np.sqrt(phi) * np.random.randn(phi.shape[0])))

Check warning on line 72 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L68-L72

Added lines #L68 - L72 were not covered by tests
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a # pragma: no cover here

raise NotImplementedError

Check warning on line 74 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L74

Added line #L74 was not covered by tests

whiteresiduals = []
for delay, ndiag in zip(delays, ndiags):
if ndiag is None:
whiteresiduals.append(0)

Check warning on line 79 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L79

Added line #L79 was not covered by tests
elif isinstance(ndiag, ShermanMorrison):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a good idea to check for an instance of ShermanMorrison. When 'fastshermanmorrison' is used, this will be a different type. Instead, perhaps duck typing can be used, like:

if all(hasattr(ndiag, attr) for attr in ['_nvec', '_jvec', '_slices']):

# this code is very slow...
n = np.diag(ndiag._nvec)
for j, s in zip(ndiag._jvec, ndiag._slices):
n[s, s] += j
whiteresiduals.append(delay + np.dot(sl.cholesky(n, lower=True), np.random.randn(n.shape[0])))

Check warning on line 85 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L82-L85

Added lines #L82 - L85 were not covered by tests
elif ndiag.ndim == 1:
whiteresiduals.append(delay + np.sqrt(ndiag) * np.random.randn(ndiag.shape[0]))
else:
raise NotImplementedError

Check warning on line 89 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L89

Added line #L89 was not covered by tests

return [np.array(g + w) for g, w in zip(gpresiduals, whiteresiduals)]


def set_residuals(psr, y):
if isinstance(psr, list):
for p, r in zip(psr, y):
p._residuals[p._isort] = r

Check warning on line 97 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L95-L97

Added lines #L95 - L97 were not covered by tests
else:
psr._residuals[psr._isort] = y

Check warning on line 99 in enterprise/signals/utils.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/utils.py#L99

Added line #L99 was not covered by tests


class ConditionalGP:
def __init__(self, pta, phiinv_method="cliques"):
"""This class allows the computation of conditional means and
Expand Down Expand Up @@ -208,89 +276,6 @@
return ret[0] if n == 1 else ret


class KernelMatrix(np.ndarray):
def __new__(cls, init):
if isinstance(init, int):
ret = np.zeros(init, "d").view(cls)
else:
ret = init.view(cls)

if ret.ndim == 2:
ret._cliques = -1 * np.ones(ret.shape[0])
ret._clcount = 0

return ret

# see PTA._setcliques
def _setcliques(self, idxs):
allidx = set(self._cliques[idxs])
maxidx = max(allidx)

if maxidx == -1:
self._cliques[idxs] = self._clcount
self._clcount = self._clcount + 1
else:
self._cliques[idxs] = maxidx
if len(allidx) > 1:
self._cliques[np.in1d(self._cliques, allidx)] = maxidx

def add(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] += other
else:
if other.ndim == 1:
self[idx, idx] += other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] += other

return self

def set(self, other, idx):
if other.ndim == 2 and self.ndim == 1:
self = KernelMatrix(np.diag(self))

if self.ndim == 1:
self[idx] = other
else:
if other.ndim == 1:
self[idx, idx] = other
else:
self._setcliques(idx)
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx)
self[idx] = other

return self

def inv(self, logdet=False):
if self.ndim == 1:
inv = 1.0 / self

if logdet:
return inv, np.sum(np.log(self))
else:
return inv
else:
try:
cf = sl.cho_factor(self)
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0]))
if logdet:
ld = 2.0 * np.sum(np.log(np.diag(cf[0])))
except np.linalg.LinAlgError:
u, s, v = np.linalg.svd(self)
inv = np.dot(u / s, u.T)
if logdet:
ld = np.sum(np.log(s))
if logdet:
return inv, ld
else:
return inv


def create_stabletimingdesignmatrix(designmat, fastDesign=True):
"""
Stabilize the timing-model design matrix.
Expand Down Expand Up @@ -885,8 +870,8 @@


@function
def tm_prior(weights):
return weights * 1e40
def tm_prior(weights, variance=1e40):
return weights * variance


# Physical ephemeris model utility functions
Expand Down
25 changes: 24 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import enterprise.constants as const
from enterprise.pulsar import Pulsar
from enterprise.signals import utils, parameter, signal_base, white_signals, gp_signals
from enterprise.signals import anis_coefficients as anis
from enterprise.signals import utils
from tests.enterprise_test_data import datadir


Expand All @@ -26,6 +26,7 @@ def setUpClass(cls):

# initialize Pulsar class
cls.psr = Pulsar(datadir + "/B1855+09_NANOGrav_9yv1.gls.par", datadir + "/B1855+09_NANOGrav_9yv1.tim")
cls.psr2 = Pulsar(datadir + "/J1909-3744_NANOGrav_9yv1.gls.par", datadir + "/J1909-3744_NANOGrav_9yv1.tim")

cls.F, _ = utils.createfourierdesignmatrix_red(cls.psr.toas, nmodes=30)

Expand All @@ -35,6 +36,28 @@ def setUpClass(cls):

cls.Mm = utils.create_stabletimingdesignmatrix(cls.psr.Mmat)

def test_simulate(self):
ef = white_signals.MeasurementNoise()

ec = gp_signals.EcorrBasisModel()

pl = utils.powerlaw(log10_A=parameter.Uniform(-16, -13), gamma=parameter.Uniform(1, 7))
orf = utils.hd_orf()
crn = gp_signals.FourierBasisCommonGP(pl, orf, components=20, name="GW")

m = ef + ec + crn

pta = signal_base.PTA([m(self.psr), m(self.psr2)])

ys = utils.simulate(pta, params=parameter.sample(pta.params))

msg = "Simulated residuals shape incorrect"
assert ys[0].shape == self.psr.residuals.shape, msg
assert ys[1].shape == self.psr2.residuals.shape, msg

msg = "Simulated residuals shape not a number"
assert np.all(~np.isnan(np.concatenate(ys))), msg

def test_createstabletimingdesignmatrix(self):
"""Timing model design matrix shape."""

Expand Down