Skip to content

Commit

Permalink
Merge pull request #356 from vhaasteren/quant2ind_returns_indices
Browse files Browse the repository at this point in the history
Allow for index arrays in ShermanMorrison
  • Loading branch information
AaronDJohnson committed Nov 17, 2023
2 parents 2415762 + b9a3bcc commit d782c29
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 75 deletions.
69 changes: 36 additions & 33 deletions enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix
from enterprise.signals.utils import indices_from_slice

from enterprise import __version__
from sys import version
Expand Down Expand Up @@ -1118,6 +1119,7 @@ class BlockMatrix(object):
def __init__(self, blocks, slices, nvec=0):
self._blocks = blocks
self._slices = slices
self._idxs = [indices_from_slice(slc) for slc in slices]
self._nvec = nvec

if np.any(nvec != 0):
Expand Down Expand Up @@ -1152,15 +1154,15 @@ def _solve_ZNX(self, X, Z):
ZNXr = np.dot(Z[self._idx, :].T, X[self._idx, :] / self._nvec[self._idx, None])
else:
ZNXr = 0
for slc, block in zip(self._slices, self._blocks):
Zblock = Z[slc, :]
Xblock = X[slc, :]
for idx, block in zip(self._idxs, self._blocks):
Zblock = Z[idx, :]
Xblock = X[idx, :]

if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
bx = sl.cho_solve(cf, Xblock)
else:
bx = Xblock / self._nvec[slc][:, None]
bx = Xblock / self._nvec[idx][:, None]
ZNX += np.dot(Zblock.T, bx)
ZNX += ZNXr
return ZNX.squeeze() if len(ZNX) > 1 else float(ZNX)
Expand All @@ -1173,11 +1175,11 @@ def _solve_NX(self, X):
X = X.reshape(X.shape[0], 1)

NX = X / self._nvec[:, None]
for slc, block in zip(self._slices, self._blocks):
Xblock = X[slc, :]
if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
NX[slc] = sl.cho_solve(cf, Xblock)
for idx, block in zip(self._idxs, self._blocks):
Xblock = X[idx, :]
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
NX[idx] = sl.cho_solve(cf, Xblock)
return NX.squeeze()

def _get_logdet(self):
Expand All @@ -1188,12 +1190,12 @@ def _get_logdet(self):
logdet = np.sum(np.log(self._nvec[self._idx]))
else:
logdet = 0
for slc, block in zip(self._slices, self._blocks):
if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
for idx, block in zip(self._idxs, self._blocks):
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
logdet += np.sum(2 * np.log(np.diag(cf[0])))
else:
logdet += np.sum(np.log(self._nvec[slc]))
logdet += np.sum(np.log(self._nvec[idx]))
return logdet

def solve(self, other, left_array=None, logdet=False):
Expand All @@ -1218,6 +1220,7 @@ class ShermanMorrison(object):
def __init__(self, jvec, slices, nvec=0.0):
self._jvec = jvec
self._slices = slices
self._idxs = [indices_from_slice(slc) for slc in slices]
self._nvec = nvec

def __add__(self, other):
Expand All @@ -1235,12 +1238,12 @@ def _solve_D1(self, x):
"""Solves :math:`N^{-1}x` where :math:`x` is a vector."""

Nx = x / self._nvec
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
rblock = x[slc]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
rblock = x[idx]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
Nx[slc] -= beta * np.dot(niblock, rblock) * niblock
Nx[idx] -= beta * np.dot(niblock, rblock) * niblock
return Nx

def _solve_1D1(self, x, y):
Expand All @@ -1250,11 +1253,11 @@ def _solve_1D1(self, x, y):

Nx = x / self._nvec
yNx = np.dot(y, Nx)
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
xblock = x[slc]
yblock = y[slc]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
xblock = x[idx]
yblock = y[idx]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
yNx -= beta * np.dot(niblock, xblock) * np.dot(niblock, yblock)
return yNx
Expand All @@ -1265,11 +1268,11 @@ def _solve_2D2(self, X, Z):
"""

ZNX = np.dot(Z.T / self._nvec, X)
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
Zblock = Z[slc, :]
Xblock = X[slc, :]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
Zblock = Z[idx, :]
Xblock = X[idx, :]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
zn = np.dot(niblock, Zblock)
xn = np.dot(niblock, Xblock)
Expand All @@ -1281,9 +1284,9 @@ def _get_logdet(self):
is a quantization matrix.
"""
logdet = np.einsum("i->", np.log(self._nvec))
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
logdet += np.log(jv) - np.log(beta)
return logdet
Expand Down
25 changes: 18 additions & 7 deletions enterprise/signals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,26 +767,37 @@ def create_quantization_matrix(toas, dt=1, nmin=2):
return U, weights


def quant2ind(U):
def quant2ind(U, as_slice=False):
"""
Use quantization matrix to return slices of non-zero elements.
Use quantization matrix to return indices of non-zero elements.
:param U: quantization matrix
:param as_slice: whether to return a slice object
:return: list of `slice`s for non-zero elements of U
:return: list of `slice`s or indices for non-zero elements of U
.. note:: This function assumes that the pulsar TOAs were sorted by time.
.. note:: For slice objects the TOAs need to be sorted by time
"""
inds = []
for cc, col in enumerate(U.T):
epinds = np.flatnonzero(col)
if epinds[-1] - epinds[0] + 1 != len(epinds):
raise ValueError("ERROR: TOAs not sorted properly!")
inds.append(slice(epinds[0], epinds[-1] + 1))
if epinds[-1] - epinds[0] + 1 != len(epinds) or not as_slice:
inds.append(epinds)
else:
inds.append(slice(epinds[0], epinds[-1] + 1))
return inds


def indices_from_slice(slc):
"""Given a slice object, return an index arrays"""

if isinstance(slc, np.ndarray):
return slc
else:
return np.arange(*slc.indices(slc.stop))


def linear_interp_basis(toas, dt=30 * 86400):
"""Provides a basis for linear interpolation.
Expand Down
37 changes: 20 additions & 17 deletions enterprise/signals/white_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enterprise.signals import parameter, selections, signal_base, utils
from enterprise.signals.parameter import function
from enterprise.signals.selections import Selection
from enterprise.signals.utils import indices_from_slice

try:
import fastshermanmorrison.fastshermanmorrison as fastshermanmorrison
Expand Down Expand Up @@ -217,13 +218,18 @@ def __init__(self, psr):
nepoch = sum(U.shape[1] for U in Umats)
U = np.zeros((len(psr.toas), nepoch))
self._slices = {}
self._idxs = {}
netot = 0
for ct, (key, mask) in enumerate(zip(keys, masks)):
nn = Umats[ct].shape[1]
U[mask, netot : nn + netot] = Umats[ct]
self._slices.update({key: utils.quant2ind(U[:, netot : nn + netot])})
netot += nn

self._idxs.update(
{key: [indices_from_slice(slc) for slc in slices] for (key, slices) in self._slices.items()}
)

# initialize sparse matrix
self._setup(psr)

Expand Down Expand Up @@ -252,17 +258,17 @@ def _setup(self, psr):

def _setup_sparse(self, psr):
Ns = scipy.sparse.csc_matrix((len(psr.toas), len(psr.toas)))
for key, slices in self._slices.items():
for slc in slices:
if slc.stop - slc.start > 1:
Ns[slc, slc] = 1.0
for key, idxs in self._idxs.items():
for idx in idxs:
if len(idx) > 1:
Ns[np.ix_(idx, idx)] = 1.0
self._Ns = signal_base.csc_matrix_alt(Ns)

def _get_ndiag_sparse(self, params):
for p in self._params:
for slc in self._slices[p]:
if slc.stop - slc.start > 1:
self._Ns[slc, slc] = 10 ** (2 * self.get(p, params))
for idx in self._idxs[p]:
if len(idx) > 1:
self._Ns[np.ix_(idx, idx)] = 10 ** (2 * self.get(p, params))
return self._Ns

def _get_ndiag_sherman_morrison(self, params):
Expand All @@ -274,21 +280,18 @@ def _get_ndiag_fast_sherman_morrison(self, params):
return fastshermanmorrison.FastShermanMorrison(jvec, slices)

def _get_ndiag_block(self, params):
slices, jvec = self._get_jvecs(params)
idxs, jvec = self._get_jvecs(params)
blocks = []
for jv, slc in zip(jvec, slices):
nb = slc.stop - slc.start
for jv, idx in zip(jvec, idxs):
nb = len(idx)
blocks.append(np.ones((nb, nb)) * jv)
return signal_base.BlockMatrix(blocks, slices)
return signal_base.BlockMatrix(blocks, idxs)

def _get_jvecs(self, params):
slices = sum([self._slices[key] for key in sorted(self._slices.keys())], [])
idxs = sum([self._idxs[key] for key in sorted(self._idxs.keys())], [])
jvec = np.concatenate(
[
np.ones(len(self._slices[key])) * 10 ** (2 * self.get(key, params))
for key in sorted(self._slices.keys())
]
[np.ones(len(self._idxs[key])) * 10 ** (2 * self.get(key, params)) for key in sorted(self._idxs.keys())]
)
return (slices, jvec)
return (idxs, jvec)

return EcorrKernelNoise
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,27 @@ def test_quantization_matrix(self):
assert U.shape == (4005, 235), msg1
assert all(np.sum(U, axis=0) > 1), msg2

inds = utils.quant2ind(U, as_slice=False)
slcs = utils.quant2ind(U, as_slice=True)
inds_check = [utils.indices_from_slice(slc) for slc in slcs]

msg3 = "Quantization Matrix slice not equal to quantization indices"
for ind, ind_c in zip(inds, inds_check):
assert np.all(ind == ind_c), msg3

def test_indices_from_slice(self):
"""Test conversion of slices to numpy indices"""
ind_np = np.array([2, 4, 6, 8])
ind_np_check = utils.indices_from_slice(ind_np)

msg1 = "Numpy indices not left as-is by indices_from_slice"
assert np.all(ind_np == ind_np_check), msg1

slc = slice(2, 10, 2)
ind_np_check = utils.indices_from_slice(slc)
msg2 = "Slice not converted properly by indices_from_slice"
assert np.all(ind_np == ind_np_check), msg2

def test_psd(self):
"""Test PSD functions."""
Tmax = self.psr.toas.max() - self.psr.toas.min()
Expand Down
Loading

0 comments on commit d782c29

Please sign in to comment.