Marginalapprox fix #6076

Merged · 6 commits · Sep 1, 2022

Summary: replaces the `DensityDist`-based implementation of `MarginalApprox.marginal_likelihood` with a `pm.Potential` term, removes the deprecated `is_observed` argument, and reworks the Marginal-vs-MarginalApprox tests to compare MAP fits and predictions rather than logp values.
48 changes: 4 additions & 44 deletions pymc/gp/gp.py
@@ -685,18 +685,13 @@ def __init__(self, approx="VFE", *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)

     def __add__(self, other):
-        # new_gp will default to FITC approx
         new_gp = super().__add__(other)
-        # make sure new gp has correct approx
         if not self.approx == other.approx:
             raise TypeError("Cannot add GPs with different approximations")
         new_gp.approx = self.approx
         return new_gp

-    # Use y as first argument, so that we can use functools.partial
-    # in marginal_likelihood instead of lambda. This makes pickling
-    # possible.
-    def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
+    def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
         sigma2 = at.square(sigma)
         Kuu = self.cov_func(Xu)
         Kuf = self.cov_func(Xu, X)
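Aside: the `__add__` guard above means two MarginalApprox GPs can be combined only if they use the same approximation; the sum then keeps that approximation. A hypothetical sketch (the kernel and GP instances below are assumptions for illustration, not from this PR):

    import pymc as pm

    cov = pm.gp.cov.ExpQuad(1, ls=1.0)
    gp1 = pm.gp.MarginalApprox(cov_func=cov, approx="VFE")
    gp2 = pm.gp.MarginalApprox(cov_func=cov, approx="VFE")
    gp_sum = gp1 + gp2  # OK: approximations match, so gp_sum.approx == "VFE"

    gp3 = pm.gp.MarginalApprox(cov_func=cov, approx="FITC")
    # gp1 + gp3 raises TypeError("Cannot add GPs with different approximations")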
@@ -725,9 +720,7 @@ def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
         quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
         return -1.0 * (constant + logdet + quadratic + trace)

-    def marginal_likelihood(
-        self, name, X, Xu, y, noise=None, is_observed=True, jitter=JITTER_DEFAULT, **kwargs
-    ):
+    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the approximate marginal likelihood distribution, given the input
         locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -747,9 +740,6 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
             noise. Must have shape `(n, )`.
         noise: scalar, Variable
             Standard deviation of the Gaussian noise.
-        is_observed: bool
-            Whether to set `y` as an `observed` variable in the `model`.
-            Default is `True`.
         jitter: scalar
             A small correction added to the diagonal of positive semi-definite
             covariance matrices to ensure numerical stability.
@@ -767,38 +757,8 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
         else:
             self.sigma = noise

-        if is_observed:
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                size=X.shape[0],
-                **kwargs,
-            )
-        else:
-            warnings.warn(
-                "The 'is_observed' argument has been deprecated. If the GP is "
-                "unobserved use gp.Latent instead.",
-                FutureWarning,
-            )
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                # ndim_supp=1,
-                size=X.shape[0],
-                **kwargs,
-            )
+        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)

     def _build_conditional(
         self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter
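Net effect of the gp.py change: `MarginalApprox.marginal_likelihood` no longer wraps the approximate likelihood in a `pm.DensityDist` with `observed=y`; it adds the log-likelihood term directly to the model as a `pm.Potential` and no longer returns a random variable. A minimal usage sketch follows (the toy data, kernel, and priors are illustrative assumptions, not taken from this PR):

    import numpy as np
    import pymc as pm

    # Toy 1-D regression data (assumed for illustration).
    rng = np.random.default_rng(0)
    X = np.linspace(0, 10, 100)[:, None]
    Xu = np.linspace(0, 10, 10)[:, None]  # inducing point locations
    y = np.sin(X).ravel() + 0.1 * rng.standard_normal(100)

    with pm.Model() as model:
        ell = pm.Gamma("ell", alpha=2.0, beta=1.0)
        cov_func = pm.gp.cov.ExpQuad(1, ls=ell)
        gp = pm.gp.MarginalApprox(cov_func=cov_func, approx="VFE")
        sigma = pm.HalfNormal("sigma", sigma=1.0)
        # After this PR: registers a Potential named
        # "marginalapprox_loglik_lik" in the model instead of
        # returning a DensityDist random variable.
        gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, noise=sigma)
        mp = pm.find_MAP()

Because `pm.Potential` creates no random variable, predictions come from `gp.conditional` (or `gp.predict`) rather than from posterior-predictive sampling of the likelihood itself.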
104 changes: 56 additions & 48 deletions pymc/tests/test_gp.py
@@ -846,63 +846,71 @@ def testLatent2(self):

 class TestMarginalVsMarginalApprox:
     R"""
-    Compare logp of models Marginal and MarginalApprox.
-    Should be nearly equal when inducing points are same as inputs.
+    Compare test fits of models Marginal and MarginalApprox.
     """

     def setup_method(self):
-        X = np.random.randn(50, 3)
-        y = np.random.randn(50)
-        Xnew = np.random.randn(60, 3)
-        pnew = np.random.randn(60)
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
-            sigma = 0.1
-            f = gp.marginal_likelihood("f", X, y, noise=sigma)
-            p = gp.conditional("p", Xnew)
-        self.logp = model.compile_logp()({"p": pnew})
-        self.X = X
-        self.Xnew = Xnew
-        self.y = y
-        self.sigma = sigma
-        self.pnew = pnew
-        self.gp = gp
+        self.sigma = 0.1
+        self.x = np.linspace(-5, 5, 30)
+        self.y = np.random.normal(0.25 * self.x, self.sigma)
+        with pm.Model() as model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0)  # far from true value
+            mean_func = pm.gp.mean.Constant(c)
+            self.gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            sigma = pm.HalfNormal("sigma", sigma=100)
+            self.gp.marginal_likelihood("lik", self.x[:, None], self.y, sigma)
+            self.map_full = pm.find_MAP(method="bfgs")  # bfgs seems to work much better than lbfgsb
+
+        self.x_new = np.linspace(-6, 6, 20)
+
+        # Include additive Gaussian noise, return diagonal of predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_var = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=True, diag=True
+            )

-    @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testApproximations(self, approx):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            p = gp.conditional("p", self.Xnew)
-            approx_logp = model.compile_logp()({"p": self.pnew})
-        npt.assert_allclose(approx_logp, self.logp, atol=0, rtol=1e-2)
+        # Dont include additive Gaussian noise, return full predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_covar = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=False, diag=False
+            )

     @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testPredictVar(self, approx):
+    def test_fits_and_preds(self, approx):
+        """Get MAP estimate for GP approximation, compare results and predictions to what's returned
+        by an unapproximated GP. The tolerances are fairly wide, but narrow relative to initial
+        values of the unknown parameters.
+        """

         with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0, initval=-500.0)
+            mean_func = pm.gp.mean.Constant(c)
             gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, var1 = self.gp.predict(self.Xnew, diag=True)
-            mu2, var2 = gp.predict(self.Xnew, diag=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(var1, var2, atol=0, rtol=1e-3)
+            sigma = pm.HalfNormal("sigma", sigma=100, initval=50.0)
+            gp.marginal_likelihood("lik", self.x[:, None], self.x[:, None], self.y, sigma)
+            map_approx = pm.find_MAP(method="bfgs")
+
+        # Check MAP gets approximately correct result
+        npt.assert_allclose(self.map_full["c"], map_approx["c"], atol=0.01, rtol=0.1)
+        npt.assert_allclose(self.map_full["sigma"], map_approx["sigma"], atol=0.01, rtol=0.1)
+
+        # Check that predict (and conditional) work, include noise, with diagonal non-full pred var.
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
+
-    def testPredictCov(self):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx="DTC")
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, cov1 = self.gp.predict(self.Xnew, pred_noise=True)
-            mu2, cov2 = gp.predict(self.Xnew, pred_noise=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(cov1, cov2, atol=0, rtol=1e-3)
+        # Check that predict (and conditional) work, no noise, full pred covariance.
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)


class TestGPAdditive:
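The test rework relies on the fact that terms registered with `pm.Potential` are summed into the model's joint log-probability, so `pm.find_MAP` (and MCMC) see the approximate GP likelihood even though no observed variable exists. A tiny self-contained sketch of that mechanism, with assumed toy values (not from this PR):

    import numpy as np
    import pymc as pm

    y_obs = np.array([0.9, 1.1, 1.0])

    with pm.Model() as m:
        mu = pm.Normal("mu", 0.0, 10.0)
        # Hand-written Gaussian log-likelihood, equal up to an additive
        # constant to the logp of pm.Normal("y", mu, 0.1, observed=y_obs).
        loglik = -0.5 * ((y_obs - mu) / 0.1) ** 2
        pm.Potential("loglik", loglik.sum())
        map_est = pm.find_MAP()  # the Potential term drives the fit; mu ends up near 1.0

This is why the updated test compares MAP fits and predictions between Marginal and MarginalApprox, rather than comparing model logp values through an observed variable as the old test did.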