-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce utils.simulate() #348
base: dev
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
import enterprise | ||
from enterprise import constants as const | ||
from enterprise import signals as sigs # noqa: F401 | ||
from enterprise.signals.signal_base import ShermanMorrison | ||
from enterprise.signals.gp_bases import ( # noqa: F401 | ||
createfourierdesignmatrix_dm, | ||
createfourierdesignmatrix_env, | ||
|
@@ -31,6 +32,73 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
def simulate(pta, params, sparse_cholesky=True): | ||
"""Simulate residuals for all pulsars in `pta` by sampling all white-noise | ||
and GP objects for parameters `params`. Requires GPs to have `combine=False`, | ||
and will run faster with GP ECORR. If `pta` includes a `TimingModel`, that | ||
should be created with a small `prior_variance`. This function can be used | ||
with `utils.set_residuals` to replace residuals in a `Pulsar` object. | ||
Note that any PTA built from that `Pulsar` may nevertheless cache residuals | ||
internally, so it is safer to rebuild the PTA with the modified `Pulsar`.""" | ||
|
||
delays, ndiags, fmats, phis = ( | ||
pta.get_delay(params=params), | ||
pta.get_ndiag(params=params), | ||
pta.get_basis(params=params), | ||
pta.get_phi(params=params), | ||
) | ||
|
||
gpresiduals = [] | ||
if pta._commonsignals: | ||
if sparse_cholesky: | ||
cf = cholesky(sps.csc_matrix(phis)) | ||
gp = np.zeros(phis.shape[0]) | ||
gp[cf.P()] = np.dot(cf.L().toarray(), np.random.randn(phis.shape[0])) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a #pragma: no cover |
||
gp = np.dot(sl.cholesky(phis, lower=True), np.random.randn(phis.shape[0])) | ||
|
||
i = 0 | ||
for fmat in fmats: | ||
j = i + fmat.shape[1] | ||
gpresiduals.append(np.dot(fmat, gp[i:j])) | ||
i = j | ||
|
||
assert len(gp) == i | ||
else: | ||
for fmat, phi in zip(fmats, phis): | ||
if phi is None: | ||
gpresiduals.append(0) | ||
elif phi.ndim == 1: | ||
gpresiduals.append(np.dot(fmat, np.sqrt(phi) * np.random.randn(phi.shape[0]))) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a # pragma: no cover here |
||
raise NotImplementedError | ||
|
||
whiteresiduals = [] | ||
for delay, ndiag in zip(delays, ndiags): | ||
if ndiag is None: | ||
whiteresiduals.append(0) | ||
elif isinstance(ndiag, ShermanMorrison): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a good idea to check for an instance of ShermanMorrison. When 'fastshermanmorrison' is used, this will be a different type. Instead, perhaps duck typing can be used, like:
|
||
# this code is very slow... | ||
n = np.diag(ndiag._nvec) | ||
for j, s in zip(ndiag._jvec, ndiag._slices): | ||
n[s, s] += j | ||
whiteresiduals.append(delay + np.dot(sl.cholesky(n, lower=True), np.random.randn(n.shape[0]))) | ||
elif ndiag.ndim == 1: | ||
whiteresiduals.append(delay + np.sqrt(ndiag) * np.random.randn(ndiag.shape[0])) | ||
else: | ||
raise NotImplementedError | ||
|
||
return [np.array(g + w) for g, w in zip(gpresiduals, whiteresiduals)] | ||
|
||
|
||
def set_residuals(psr, y): | ||
if isinstance(psr, list): | ||
for p, r in zip(psr, y): | ||
p._residuals[p._isort] = r | ||
else: | ||
psr._residuals[psr._isort] = y | ||
|
||
|
||
class ConditionalGP: | ||
def __init__(self, pta, phiinv_method="cliques"): | ||
"""This class allows the computation of conditional means and | ||
|
@@ -208,89 +276,6 @@ | |
return ret[0] if n == 1 else ret | ||
|
||
|
||
class KernelMatrix(np.ndarray): | ||
def __new__(cls, init): | ||
if isinstance(init, int): | ||
ret = np.zeros(init, "d").view(cls) | ||
else: | ||
ret = init.view(cls) | ||
|
||
if ret.ndim == 2: | ||
ret._cliques = -1 * np.ones(ret.shape[0]) | ||
ret._clcount = 0 | ||
|
||
return ret | ||
|
||
# see PTA._setcliques | ||
def _setcliques(self, idxs): | ||
allidx = set(self._cliques[idxs]) | ||
maxidx = max(allidx) | ||
|
||
if maxidx == -1: | ||
self._cliques[idxs] = self._clcount | ||
self._clcount = self._clcount + 1 | ||
else: | ||
self._cliques[idxs] = maxidx | ||
if len(allidx) > 1: | ||
self._cliques[np.in1d(self._cliques, allidx)] = maxidx | ||
|
||
def add(self, other, idx): | ||
if other.ndim == 2 and self.ndim == 1: | ||
self = KernelMatrix(np.diag(self)) | ||
|
||
if self.ndim == 1: | ||
self[idx] += other | ||
else: | ||
if other.ndim == 1: | ||
self[idx, idx] += other | ||
else: | ||
self._setcliques(idx) | ||
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) | ||
self[idx] += other | ||
|
||
return self | ||
|
||
def set(self, other, idx): | ||
if other.ndim == 2 and self.ndim == 1: | ||
self = KernelMatrix(np.diag(self)) | ||
|
||
if self.ndim == 1: | ||
self[idx] = other | ||
else: | ||
if other.ndim == 1: | ||
self[idx, idx] = other | ||
else: | ||
self._setcliques(idx) | ||
idx = (idx, idx) if isinstance(idx, slice) else (idx[:, None], idx) | ||
self[idx] = other | ||
|
||
return self | ||
|
||
def inv(self, logdet=False): | ||
if self.ndim == 1: | ||
inv = 1.0 / self | ||
|
||
if logdet: | ||
return inv, np.sum(np.log(self)) | ||
else: | ||
return inv | ||
else: | ||
try: | ||
cf = sl.cho_factor(self) | ||
inv = sl.cho_solve(cf, np.identity(cf[0].shape[0])) | ||
if logdet: | ||
ld = 2.0 * np.sum(np.log(np.diag(cf[0]))) | ||
except np.linalg.LinAlgError: | ||
u, s, v = np.linalg.svd(self) | ||
inv = np.dot(u / s, u.T) | ||
if logdet: | ||
ld = np.sum(np.log(s)) | ||
if logdet: | ||
return inv, ld | ||
else: | ||
return inv | ||
|
||
|
||
def create_stabletimingdesignmatrix(designmat, fastDesign=True): | ||
""" | ||
Stabilize the timing-model design matrix. | ||
|
@@ -885,8 +870,8 @@ | |
|
||
|
||
@function | ||
def tm_prior(weights): | ||
return weights * 1e40 | ||
def tm_prior(weights, variance=1e40): | ||
return weights * variance | ||
|
||
|
||
# Physical ephemeris model utility functions | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a # pragma: no cover here?