Add multi-output support to GP Latent #7471

Merged · 3 commits · Aug 30, 2024
74 changes: 61 additions & 13 deletions pymc/gp/gp.py
@@ -148,18 +148,37 @@
     def __init__(self, *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
-    def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def _build_prior(
+        self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs
+    ):
         mu = self.mean_func(X)
         cov = stabilize(self.cov_func(X), jitter)
         if reparameterize:
-            size = np.shape(X)[0]
-            v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
-            f = pm.Deterministic(name, mu + cholesky(cov).dot(v), dims=kwargs.get("dims", None))
+            if "dims" in kwargs:
+                v = pm.Normal(
+                    name + "_rotated_",
+                    mu=0.0,
+                    sigma=1.0,
+                    **kwargs,
+                )
+
+            else:
+                size = (n_outputs, X.shape[0]) if n_outputs > 1 else X.shape[0]
+                v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
+
+            f = pm.Deterministic(
+                name,
+                mu + cholesky(cov).dot(v.T).transpose(),
+                dims=kwargs.get("dims", None),
+            )
         else:
-            f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+            mu_stack = pt.stack([mu] * n_outputs, axis=0) if n_outputs > 1 else mu
+            f = pm.MvNormal(name, mu=mu_stack, cov=cov, **kwargs)
 
         return f
 
-    def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def prior(self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the GP prior distribution evaluated over the input
         locations `X`.
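Reviewer note: the reparameterized branch now draws an `(n_outputs, n)` block of standard normals and rotates every row with the same Cholesky factor. A minimal NumPy sketch of what the new `_build_prior` computes; function and variable names here are illustrative, not part of the diff:

```python
import numpy as np

def build_prior_sketch(mu, cov, n_outputs, rng):
    """Reparameterized draw: f = mu + (L @ v.T).T, one row per output GP."""
    n = mu.shape[0]
    L = np.linalg.cholesky(cov + 1e-6 * np.eye(n))  # mirrors stabilize(cov, jitter)
    size = (n_outputs, n) if n_outputs > 1 else n
    v = rng.standard_normal(size)                   # plays the role of "<name>_rotated_"
    return mu + L.dot(v.T).T                        # (n_outputs, n), or (n,) when n_outputs == 1

# e.g. build_prior_sketch(np.zeros(5), np.eye(5), n_outputs=3, rng=np.random.default_rng(0))
```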
@@ -178,6 +197,12 @@
         X : array-like
             Function input values. If one-dimensional, must be a column
             vector with shape `(n, 1)`.
+        n_outputs : int, default 1
+            Number of output GPs. If you're using `dims`, make sure their size
+            is equal to `(n_outputs, X.shape[0])`, i.e. the number of output GPs
+            by the number of input points.
+            Example: `gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))`,
+            where `len(n_gps) = 3` and `len(x_dim) = X.shape[0]`.
         reparameterize : bool, default True
             Reparameterize the distribution by rotating the random
             variable by the Cholesky factor of the covariance matrix.
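For the `dims` path, the coordinates must already match the `(n_outputs, X.shape[0])` shape described above. A hedged usage sketch, with coordinate names made up for illustration:

```python
import numpy as np
import pymc as pm

X = np.linspace(0, 10, 50)[:, None]
coords = {"n_gps": range(3), "x_dim": range(X.shape[0])}

with pm.Model(coords=coords):
    gp = pm.gp.Latent(cov_func=pm.gp.cov.ExpQuad(1, ls=1.0))
    # Three output GPs sharing one covariance; f has shape (3, 50)
    f = gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))
```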
@@ -188,10 +213,12 @@
             Extra keyword arguments that are passed to :class:`~pymc.MvNormal`
             distribution constructor.
         """
-        f = self._build_prior(name, X, reparameterize, jitter, **kwargs)
+        f = self._build_prior(name, X, n_outputs, reparameterize, jitter, **kwargs)
+
         self.X = X
         self.f = f
+        self.n_outputs = n_outputs
+
         return f
 
     def _get_given_vals(self, given):
@@ -212,12 +239,16 @@
     def _build_conditional(self, Xnew, X, f, cov_total, mean_total, jitter):
         Kxx = cov_total(X)
         Kxs = self.cov_func(X, Xnew)
+
         L = cholesky(stabilize(Kxx, jitter))
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, f - mean_total(X))
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, (f - mean_total(X)).T)
+
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         Kss = self.cov_func(Xnew)
         cov = Kss - pt.dot(pt.transpose(A), A)
+
         return mu, cov
 
     def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
@@ -255,7 +286,9 @@
         """
         givens = self._get_given_vals(given)
         mu, cov = self._build_conditional(Xnew, *givens, jitter)
-        return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+        f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+
+        return f
 
 
 @conditioned_vars(["X", "f", "nu"])
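Call sites for `conditional` do not change: it reads `X` and `f` from the state stored by `prior`, and the multi-output shape follows from `f`. A hedged, self-contained sketch:

```python
import numpy as np
import pymc as pm

X = np.linspace(0, 10, 50)[:, None]
Xnew = np.linspace(10, 12, 20)[:, None]

with pm.Model():
    gp = pm.gp.Latent(cov_func=pm.gp.cov.ExpQuad(1, ls=1.0))
    f = gp.prior("f", X=X, n_outputs=3)   # (3, 50)
    fnew = gp.conditional("fnew", Xnew)   # (3, 20): output dimension follows the stored f
```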
@@ -447,7 +480,15 @@
         return mu, stabilize(cov, jitter)
 
     def marginal_likelihood(
-        self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
+        self,
+        name,
+        X,
+        y,
+        sigma=None,
+        noise=None,
+        jitter=JITTER_DEFAULT,
+        is_observed=True,
+        **kwargs,
     ):
         R"""
         Returns the marginal likelihood distribution, given the input
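The signature change here is purely cosmetic (one argument per line), but the new tests below also pass a stacked `y` through this path. A hedged sketch of what that looks like, with sizes mirroring the test:

```python
import numpy as np
import pymc as pm

n_outputs, n = 2, 20
X = np.random.randn(n, 3)
y = np.random.randn(n_outputs, n)  # one row of observations per output

with pm.Model():
    gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(3, ls=[0.1, 0.2, 0.3]))
    f = gp.marginal_likelihood("f", X=X, y=y, sigma=0.1)
```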
@@ -529,21 +570,28 @@
         Kxs = self.cov_func(X, Xnew)
         Knx = noise_func(X)
         rxx = y - mean_total(X)
+
         L = cholesky(stabilize(Kxx, jitter) + Knx)
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, rxx)
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, rxx.T)
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         if diag:
             Kss = self.cov_func(Xnew, diag=True)
             var = Kss - pt.sum(pt.square(A), 0)
+
             if pred_noise:
                 var += noise_func(Xnew, diag=True)
+
             return mu, var
+
         else:
             Kss = self.cov_func(Xnew)
             cov = Kss - pt.dot(pt.transpose(A), A)
+
             if pred_noise:
                 cov += noise_func(Xnew)
+
             return mu, cov if pred_noise else stabilize(cov, jitter)
 
     def conditional(
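Same pattern in `Marginal`'s predictive path: only the residual solve and the mean pick up transposes, while the noise handling and the `diag` fast path are untouched. A self-contained sketch of the diagonal branch, with illustrative names:

```python
import numpy as np
from scipy.linalg import solve_triangular

def predict_diag_sketch(Kxx, Knx, Kxs, Kss_diag, y, mean_X, mean_Xnew, jitter=1e-6):
    """Pointwise predictive mean/variance; y may be (n,) or (n_outputs, n)."""
    n = Kxx.shape[0]
    L = np.linalg.cholesky(Kxx + Knx + jitter * np.eye(n))
    A = solve_triangular(L, Kxs, lower=True)
    v = solve_triangular(L, (y - mean_X).T, lower=True)
    mu = mean_Xnew + (A.T @ v).T             # (n_outputs, n_new)
    var = Kss_diag - np.sum(A**2, axis=0)    # variance is shared across output rows
    return mu, var
```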
13 changes: 9 additions & 4 deletions pymc/gp/hsgp_approx.py
@@ -442,18 +442,23 @@ def prior(
             Dimension name for the GP random variable.
         """
         phi, sqrt_psd = self.prior_linearized(X)
+        self._sqrt_psd = sqrt_psd
+
         if self._parametrization == "noncentered":
             self._beta = pm.Normal(
-                f"{name}_hsgp_coeffs_",
-                size=self._m_star - int(self._drop_first),
+                f"{name}_hsgp_coeffs",
+                size=self.n_basis_vectors - int(self._drop_first),
                 dims=hsgp_coeffs_dims,
             )
-            self._sqrt_psd = sqrt_psd
             f = self.mean_func(X) + phi @ (self._beta * self._sqrt_psd)
+
         elif self._parametrization == "centered":
-            self._beta = pm.Normal(f"{name}_hsgp_coeffs_", sigma=sqrt_psd, dims=hsgp_coeffs_dims)
+            self._beta = pm.Normal(
+                f"{name}_hsgp_coeffs",
+                sigma=sqrt_psd,
+                size=self.n_basis_vectors - int(self._drop_first),
+                dims=hsgp_coeffs_dims,
+            )
             f = self.mean_func(X) + phi @ self._beta
 
         self.f = pm.Deterministic(name, f, dims=gp_dims)
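Note the coefficient variable loses its trailing underscore (`{name}_hsgp_coeffs_` becomes `{name}_hsgp_coeffs`), so any code that reached into the model by name needs the new spelling. A hedged sketch:

```python
import numpy as np
import pymc as pm

X = np.linspace(0, 10, 100)[:, None]

with pm.Model() as model:
    gp = pm.gp.HSGP(m=[25], c=4.0, cov_func=pm.gp.cov.ExpQuad(1, ls=1.0))
    f = gp.prior("f", X=X)

coeffs = model.f_hsgp_coeffs  # previously model.f_hsgp_coeffs_
```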
75 changes: 70 additions & 5 deletions tests/gp/test_gp.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import numpy.testing as npt
+import pytensor.tensor as pt
 import pytest
 
 import pymc as pm
@@ -90,7 +91,12 @@ def test_raise_value_error(self):
         with self.model:
             with pytest.raises(ValueError):
                 self.gp.marginal_likelihood(
-                    "like_both", X=self.x, Xu=self.xu, y=self.y, noise=self.sigma, sigma=self.sigma
+                    "like_both",
+                    X=self.x,
+                    Xu=self.xu,
+                    y=self.y,
+                    noise=self.sigma,
+                    sigma=self.sigma,
                 )
 
             with pytest.raises(ValueError):
@@ -177,7 +183,11 @@ def setup_method(self):
             pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
             pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
         )
-        self.means = (pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5))
+        self.means = (
+            pm.gp.mean.Constant(0.5),
+            pm.gp.mean.Constant(0.5),
+            pm.gp.mean.Constant(0.5),
+        )
 
     def testAdditiveMarginal(self):
         with pm.Model() as model1:
@@ -199,7 +209,9 @@
 
         with model1:
             fp1 = gpsum.conditional(
-                "fp1", self.Xnew, given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum}
+                "fp1",
+                self.Xnew,
+                given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum},
             )
         with model2:
             fp2 = gptot.conditional("fp2", self.Xnew)
@@ -230,7 +242,9 @@ def testAdditiveMarginalApprox(self, approx):
 
         with pm.Model() as model2:
             gptot = pm.gp.MarginalApprox(
-                mean_func=reduce(add, self.means), cov_func=reduce(add, self.covs), approx=approx
+                mean_func=reduce(add, self.means),
+                cov_func=reduce(add, self.covs),
+                approx=approx,
             )
             fsum = gptot.marginal_likelihood("f", self.X, Xu, self.y, sigma=sigma)
             model2_logp = model2.compile_logp()({})
@@ -352,6 +366,53 @@ def testLatent2(self):
         latent_logp = model.compile_logp()({"f_rotated_": y_rotated, "p": self.pnew})
         npt.assert_allclose(latent_logp, self.logp, atol=5)
 
+    def testLatentMultioutput(self):
+        n_outputs = 2
+        X = np.random.randn(20, 3)
+        y = np.random.randn(n_outputs, 20)
+        Xnew = np.random.randn(30, 3)
+        pnew = np.random.randn(n_outputs, 30)
+
+        with pm.Model() as latent_model:
+            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
+            mean_func = pm.gp.mean.Constant(0.5)
+            latent_gp = pm.gp.Latent(mean_func=mean_func, cov_func=cov_func)
+            latent_f = latent_gp.prior("f", X, n_outputs=n_outputs, reparameterize=True)
+            latent_p = latent_gp.conditional("p", Xnew)
+
+        with pm.Model() as marginal_model:
+            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
+            mean_func = pm.gp.mean.Constant(0.5)
+            marginal_gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            marginal_f = marginal_gp.marginal_likelihood("f", X, y, sigma=0.0)
+            marginal_p = marginal_gp.conditional("p", Xnew)
+
+        assert tuple(latent_f.shape.eval()) == tuple(marginal_f.shape.eval()) == y.shape
+        assert tuple(latent_p.shape.eval()) == tuple(marginal_p.shape.eval()) == pnew.shape
+
+        chol = np.linalg.cholesky(cov_func(X).eval())
+        v = np.linalg.solve(chol, (y - 0.5).T)
+        A = np.linalg.solve(chol, cov_func(X, Xnew).eval()).T
+        mu_cond = mean_func(Xnew).eval() + (A @ v).T
+        cov_cond = cov_func(Xnew, Xnew).eval() - A @ A.T
+
+        with pm.Model() as numpy_model:
+            numpy_p = pm.MvNormal.dist(mu=pt.as_tensor(mu_cond), cov=pt.as_tensor(cov_cond))
+
+        latent_rv_logp = pm.logp(latent_p, pnew)
+        marginal_rv_logp = pm.logp(marginal_p, pnew)
+        numpy_rv_logp = pm.logp(numpy_p, pnew)
+
+        assert (
+            latent_rv_logp.shape.eval()
+            == marginal_rv_logp.shape.eval()
+            == numpy_rv_logp.shape.eval()
+        )
+
+        npt.assert_allclose(latent_rv_logp.eval(), marginal_rv_logp.eval(), atol=5)
+        npt.assert_allclose(latent_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)
+        npt.assert_allclose(marginal_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)

class TestTP:
R"""
@@ -486,7 +547,11 @@ def setup_method(self):
         self.X = cartesian(*self.Xs)
         self.N = np.prod([len(X) for X in self.Xs])
         self.y = np.random.randn(self.N) * 0.1
-        self.Xnews = (np.random.randn(5, 1), np.random.randn(5, 1), np.random.randn(5, 1))
+        self.Xnews = (
+            np.random.randn(5, 1),
+            np.random.randn(5, 1),
+            np.random.randn(5, 1),
+        )
         self.Xnew = np.concatenate(self.Xnews, axis=1)
         self.sigma = 0.2
         self.pnew = np.random.randn(len(self.Xnew))
2 changes: 1 addition & 1 deletion tests/gp/test_hsgp_approx.py
@@ -186,7 +186,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
         gp = pm.gp.HSGP(m=[n_basis], c=4.0, cov_func=cov_func, drop_first=drop_first)
         gp.prior("f1", X1)
 
-        n_coeffs = model.f1_hsgp_coeffs_.type.shape[0]
+        n_coeffs = model.f1_hsgp_coeffs.type.shape[0]
         if drop_first:
             assert (
                 n_coeffs == n_basis - 1
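Relatedly, the `drop_first` bookkeeping is unchanged; only the variable name differs. A hedged check with illustrative sizes:

```python
import numpy as np
import pymc as pm

X1 = np.linspace(-5, 5, 90)[:, None]

with pm.Model() as model:
    cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)
    gp = pm.gp.HSGP(m=[100], c=4.0, cov_func=cov_func, drop_first=True)
    gp.prior("f1", X1)

# One basis vector dropped, under the renamed coefficient variable:
assert model.f1_hsgp_coeffs.type.shape[0] == 100 - 1
```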