[WIP]: Implementing Kronecker terms #6

Closed
wants to merge 30 commits into from
30 commits
b2b8ae2
starting to implement kron terms
dfm Sep 16, 2020
c252a17
dealing with shapes
dfm Sep 16, 2020
bc93ea7
fixing order of columns in kron term
dfm Oct 14, 2020
1a86887
caching theano directory
dfm Oct 14, 2020
c122d8c
cache based on compiledir
dfm Oct 14, 2020
93888d8
forgotten shell command
dfm Oct 14, 2020
a7a4791
adding implementation of low rank kron
dfm Oct 14, 2020
12e3883
implementing kron term sum
dfm Oct 14, 2020
3289b49
skipping termsumgeneral
dfm Oct 14, 2020
303bdf1
xfail
dfm Oct 14, 2020
6be5390
cache theano for tutorials too
dfm Oct 14, 2020
c746b9f
handling latent dimensions generally
dfm Oct 15, 2020
e086162
Merge branch 'kron' of https://github.com/exoplanet-dev/celerite2 int…
dfm Oct 15, 2020
256ba4d
adding tests for predict and sample
dfm Oct 16, 2020
8354026
fixing isort
dfm Oct 16, 2020
cbca530
fixing variance dimensions
dfm Oct 16, 2020
1d38e1c
typo
dfm Oct 16, 2020
5cb6d3a
starting to implement theano terms
dfm Oct 20, 2020
b26bbf2
updating theano gp for kron
dfm Oct 20, 2020
bb88ce2
updating jax an torch
dfm Oct 20, 2020
1430f46
sorting imports
dfm Oct 20, 2020
9a77620
typos
dfm Oct 20, 2020
4d6d3ca
torch sizes
dfm Oct 20, 2020
dd91eb9
torch size again
dfm Oct 21, 2020
badd729
adding tests for citations
dfm Oct 21, 2020
7bd2ce3
adding kron tutorial
dfm Oct 21, 2020
40ee27a
dealing with testvals
dfm Oct 21, 2020
e71fb6c
adding docstrings for kron terms
dfm Oct 21, 2020
f978884
simplifying kron needs
dfm Oct 22, 2020
3865fbe
adding initial support for missing data
dfm Nov 2, 2020
13 changes: 13 additions & 0 deletions .github/workflows/python.yml
@@ -79,6 +79,19 @@ jobs:
          conda install -q numpy scipy theano mkl-service
          python -m pip install -U pip
          python -m pip install --use-feature=2020-resolver -e ".[test,theano]"
      - name: Get theano compiledir
        id: compiledir
        shell: bash -l {0}
        run: |
          python -c "import theano; print('::set-output name=compiledir::' + theano.config.compiledir.split('/')[-1])"
      - name: "Cache ~/.theano"
        uses: actions/cache@v2
        with:
          path: ~/.theano
          key: theano-${{ steps.compiledir.outputs.compiledir }}-${{ hashFiles('python/test/theano/*.py') }}
          restore-keys: |
            theano-${{ steps.compiledir.outputs.compiledir }}-
            theano-
      - name: Run the unit tests
        shell: bash -l {0}
        run: python -m pytest --cov celerite2 python/test/theano
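The cache key in the step above combines a fixed prefix, the last path component of the Theano compiledir, and a hash of the test files. A minimal Python sketch of that composition follows. It is illustrative only: `compiledir_name`, `cache_key`, and `restore_candidates` are hypothetical helpers, the example path is made up, and SHA-256 over file contents is a stand-in for GitHub's `hashFiles`, whose exact digest is not reproduced here.

```python
import hashlib


def compiledir_name(compiledir):
    """Return the last path component, as the workflow's split('/')[-1] does."""
    return compiledir.split("/")[-1]


def cache_key(prefix, compiledir, file_contents):
    """Build '<prefix>-<compiledir name>-<digest>'.

    The digest is a SHA-256 over the given byte strings. It stands in for
    GitHub's hashFiles() and is not a reimplementation of it.
    """
    digest = hashlib.sha256()
    for content in file_contents:
        digest.update(content)
    return f"{prefix}-{compiledir_name(compiledir)}-{digest.hexdigest()}"


def restore_candidates(prefix, compiledir):
    """Fallback prefixes in the order listed under restore-keys."""
    return [f"{prefix}-{compiledir_name(compiledir)}-", f"{prefix}-"]


# Hypothetical compiledir path, for illustration only.
path = "/home/runner/.theano/compiledir_Linux-x86_64-3.8.5-64"
key = cache_key("theano", path, [b"def test_kron(): pass\n"])
fallbacks = restore_candidates("theano", path)
```

Under these assumptions, editing a test file changes the digest and so misses the exact key, while the first fallback prefix can still match a cache built for the same compiledir.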
14 changes: 14 additions & 0 deletions .github/workflows/tutorials.yml
@@ -29,6 +29,20 @@ jobs:
          python -m pip install -U pip
          python -m pip install --use-feature=2020-resolver ".[tutorials]"

      - name: Get theano compiledir
        id: compiledir
        run: |
          python -c "import theano; print('::set-output name=compiledir::' + theano.config.compiledir.split('/')[-1])"

      - name: "Cache ~/.theano"
        uses: actions/cache@v2
        with:
          path: ~/.theano
          key: tutorials-${{ steps.compiledir.outputs.compiledir }}-${{ hashFiles('docs/tutorials/*.py') }}
          restore-keys: |
            tutorials-${{ steps.compiledir.outputs.compiledir }}-
            tutorials-

      - name: Execute the notebooks
        run: |
          jupytext --to ipynb --execute docs/tutorials/*.py
1 change: 1 addition & 0 deletions docs/.gitignore
@@ -1,2 +1,3 @@
_build
c++
tutorials/*.png
12 changes: 12 additions & 0 deletions docs/api/python.rst
@@ -72,3 +72,15 @@ recommended unless you're confident that you know what you're doing.
.. autoclass:: celerite2.terms.RealTerm
.. autoclass:: celerite2.terms.ComplexTerm
.. autoclass:: celerite2.terms.OriginalCeleriteTerm

Multivariate models
-------------------

The original *celerite* algorithm was only defined for one-dimensional inputs,
but this was generalized by `Gordon et al. (2020)
<https://arxiv.org/abs/2007.05799>`_ to support multivariate inputs on tensor
product grids with separable kernels. In this case, the covariance matrix is
given by a Kronecker product. These models are available in *celerite2*
through the following class:

.. autoclass:: celerite2.kron.KronTerm
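To make the Kronecker structure described above concrete, here is a small dense NumPy sketch. It does not use *celerite2*. It assumes an exponential time kernel, a random positive semi-definite band matrix `R`, and that the `(N, M)` data are flattened row by row. That ordering is an assumption of this sketch, not something the diff confirms.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 4, 3

# k1: a celerite-style exponential kernel a * exp(-c * |dt|) on sorted times.
t = np.sort(rng.uniform(0, 10, N))
K1 = 1.5 * np.exp(-0.7 * np.abs(t[:, None] - t[None, :]))

# k2: any symmetric positive semi-definite M x M matrix.
A = rng.normal(size=(M, M))
R = A @ A.T

# Full covariance, assuming flat index n * M + m pairs time n with band m
# (an assumed ordering for this illustration).
K = np.kron(K1, R)

# Element-wise check of the separable form k = k1 * k2 for one pair of points.
n, m, n2, m2 = 2, 1, 3, 0
assert np.isclose(K[n * M + m, n2 * M + m2], K1[n, n2] * R[m, m2])
```

The dense matrix has `(N * M) ** 2` entries, which is why a solver that exploits the structure is attractive for long time series.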
1 change: 1 addition & 0 deletions docs/index.rst
@@ -36,6 +36,7 @@ be a good choice.
   :caption: Tutorials

   tutorials/first.ipynb
   tutorials/kron.ipynb

.. toctree::
   :maxdepth: 2
5 changes: 4 additions & 1 deletion docs/tutorials/first.py
@@ -36,7 +36,10 @@
 np.random.seed(42)

 t = np.sort(
-    np.append(np.random.uniform(0, 3.8, 57), np.random.uniform(5.5, 10, 68),)
+    np.append(
+        np.random.uniform(0, 3.8, 57),
+        np.random.uniform(5.5, 10, 68),
+    )
 )  # The input coordinates must be sorted
 yerr = np.random.uniform(0.08, 0.22, len(t))
 y = (
141 changes: 141 additions & 0 deletions docs/tutorials/kron.py
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.5.2
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# + nbsphinx="hidden"
# %matplotlib inline

# + nbsphinx="hidden"
# %run notebook_setup
# -

# # Multivariate models
#
# The original *celerite* package only supported one-dimensional data (like time series), but [Gordon et al. (2020)](https://arxiv.org/abs/2007.05799) generalized the method to multivariate data on a tensor product grid.
# This has been implemented in *celerite2* so "rectangular" data are now supported.
# The main application discussed by [Gordon et al. (2020)](https://arxiv.org/abs/2007.05799) was multiwavelength observations of transiting exoplanets, but the method applies to many other problems with the following structure:
#
# 1. The rectangular data must be fully filled: you must have observations in every band at every time. We've developed a method to handle missing data and that will be included in a future release.
# 2. The covariance matrix must be separable, with the form `k({x, y}_n, {x, y}_m) = k1(x_n, x_m) * k2(y_n, y_m)`, where `x` (a scalar) indexes the "longest" one-dimensional axis of the data (for example, time) and `y` (optionally a vector) indexes the narrower axis of the data (for example, wavelength). To apply *celerite*, we must make the further assumption that `k1(x_n, x_m)` is a standard *celerite* kernel, but no limitations are placed on the form of `k2(y_n, y_m)`.
#
# The implementation of this method in *celerite2* comes with two forms for the kernel:
#
# 1. `kron.KronTerm`: A general form of the model where `k2(y_n, y_m)` is specified as a full-rank `M x M` matrix called `R`, where `M` is the size of the `y` dimension. The computational cost of evaluating this model scales as `O(N * J^2 * M^3)` where `N` is the size of the `x` dimension and `J` is the rank of the *celerite* term describing `k1(x_n, x_m)`.
# 2. `kron.LowRankKronTerm`: A more computationally efficient method where

# +
import numpy as np
import matplotlib.pyplot as plt

import celerite2

N = 200
M = 5
lam = np.linspace(0, 3, M)

np.random.seed(59302)
t = np.append(
    np.sort(np.random.uniform(0, 4, N // 2)),
    np.sort(np.random.uniform(6, 10, N - N // 2)),
)
yerr = np.random.uniform(1e-1, 2e-1, (N, M))

rho_true = 4.5
R_true = 0.5 * np.exp(-0.5 * (lam[:, None] - lam[None, :]) ** 2)
kernel = celerite2.kron.KronTerm(
    celerite2.terms.SHOTerm(sigma=1.0, rho=rho_true, Q=3.0), R=R_true
)
gp = celerite2.GaussianProcess(kernel, t=t, yerr=yerr)
y = gp.sample()

plt.yticks([])
for m in range(M):
    plt.axhline(2 * m, color="k", lw=0.5)
plt.plot(t, y + 2 * np.arange(M), ".")
plt.ylim(-2, 2 * M)
plt.xlim(-1, 11)
plt.xlabel("x")
_ = plt.ylabel("y (with offsets)")

# +
import pymc3 as pm
import pymc3_ext as pmx
import celerite2.theano as cl2

with pm.Model() as model:

    rho = pm.Lognormal("rho", mu=np.log(5.0), sigma=5.0)
    chol = pm.LKJCholeskyCov(
        "chol",
        eta=10.0,
        n=M,
        sd_dist=pm.Exponential.dist(0.01),
        compute_corr=True,
    )[0]
    R = pm.Deterministic("R", pm.math.dot(chol, chol.T))

    kernel = cl2.kron.KronTerm(
        cl2.terms.SHOTerm(sigma=1.0, rho=rho, Q=3.0), R=R
    )
    gp = cl2.GaussianProcess(kernel, t=t, yerr=yerr)
    gp.marginal("obs", observed=y)

    soln = pmx.optimize()

# +
t_pred = np.linspace(-1, 11, 1000)
with model:
    mu, var = pmx.eval_in_model(gp.predict(y, t=t_pred, return_var=True), soln)

plt.yticks([])
for m in range(M):
    plt.axhline(2 * m, color="k", lw=0.5)
    plt.plot(t, y[:, m] + 2 * m, ".", color=f"C{m}")
    plt.fill_between(
        t_pred,
        mu[:, m] - np.sqrt(var[:, m]) + 2 * m,
        mu[:, m] + np.sqrt(var[:, m]) + 2 * m,
        color=f"C{m}",
        alpha=0.5,
    )
    plt.plot(t_pred, mu[:, m] + 2 * m, color=f"C{m}")

plt.ylim(-2, 2 * M)
plt.xlim(-1, 11)
plt.xlabel("x")
_ = plt.ylabel("y (with offsets)")
# -

with model:
    trace = pm.sample(
        tune=2000, draws=2000, target_accept=0.9, init="adapt_full"
    )

plt.hist(trace["rho"], 50, histtype="step", color="k")
plt.axvline(rho_true)
plt.yticks([])
plt.xlabel(r"$\rho$")
plt.ylabel(r"$p(\rho)$")

for m in range(M):
    plt.errorbar(
        np.arange(M),
        np.mean(trace["R"][:, m, :], axis=0) + m,
        yerr=np.std(trace["R"][:, m, :], axis=0),
        color=f"C{m}",
    )
    plt.plot(np.arange(M), R_true[m] + m, ":", color=f"C{m}")
plt.yticks([])
plt.xlabel("band index")
_ = plt.ylabel("covariance (with offsets)")
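As a cross-check on the kind of model used in this tutorial, the sketch below draws one sample from a separable Gaussian process using dense linear algebra only, without *celerite2*. It is a rough reference under stated assumptions: an exponential kernel stands in for the `SHOTerm`, the `(N, M)` arrays are assumed to flatten row by row, `N` is reduced to keep the dense factorization cheap, and the names `K1`, `K_dense`, and `L` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(59302)
N, M = 20, 5
lam = np.linspace(0, 3, M)

t = np.sort(rng.uniform(0, 10, N))
yerr = rng.uniform(1e-1, 2e-1, (N, M))

# Same band covariance as R_true in the tutorial.
R = 0.5 * np.exp(-0.5 * (lam[:, None] - lam[None, :]) ** 2)

# Stand-in time kernel (NOT the SHOTerm used in the tutorial).
K1 = np.exp(-np.abs(t[:, None] - t[None, :]) / 4.5)

# Dense covariance with the per-point measurement variance on the diagonal.
K_dense = np.kron(K1, R) + np.diag(yerr.ravel() ** 2)

# Sample through a Cholesky factor, then restore the (N, M) layout.
L = np.linalg.cholesky(K_dense)
y = (L @ rng.normal(size=N * M)).reshape(N, M)
```

The dense factorization scales as `O((N * M)^3)`, so a check like this is only practical for small grids. The tutorial text quotes `O(N * J^2 * M^3)` for `KronTerm`.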
4 changes: 2 additions & 2 deletions python/celerite2/__init__.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-

-__all__ = ["__version__", "terms", "GaussianProcess"]
+__all__ = ["__version__", "terms", "kron", "GaussianProcess"]

-from . import terms
+from . import kron, terms
 from .celerite2 import GaussianProcess
 from .celerite2_version import __version__