Implement whitened parametrisation #71

Merged: 33 commits, Nov 16, 2021

Changes from 21 commits

Commits (33)
4d4c182 Implement whitened parametrisation (willtebbutt, Nov 5, 2021)
ec8d49d Bump patch (willtebbutt, Nov 5, 2021)
b0e69d2 Improve docs (willtebbutt, Nov 5, 2021)
a11f40e Apply suggestions from code review (willtebbutt, Nov 5, 2021)
145d7f9 Update docs (willtebbutt, Nov 5, 2021)
be44580 SVGP -> SparseVariationalApproximation (willtebbutt, Nov 5, 2021)
523e8ec Merge branch 'wct/whitened-inference' of https://github.com/JuliaGaus… (willtebbutt, Nov 5, 2021)
065041e Fix docstring typo (willtebbutt, Nov 5, 2021)
093776e Update docs (willtebbutt, Nov 8, 2021)
f6adf96 Refactor to use type parameter (willtebbutt, Nov 8, 2021)
8cef7ba Test kl_term (willtebbutt, Nov 8, 2021)
34c5e7c Merge in changes (willtebbutt, Nov 8, 2021)
f1abc2e Run all tests (willtebbutt, Nov 8, 2021)
c688025 Clarify tests (willtebbutt, Nov 8, 2021)
cbdae25 Apply suggestions from code review (willtebbutt, Nov 8, 2021)
7bfa962 Stabilise numerics in sparse_variational tests (willtebbutt, Nov 8, 2021)
e70a176 Merge in master (willtebbutt, Nov 8, 2021)
3965248 Stabilise tests (willtebbutt, Nov 8, 2021)
46fb9ec Improve docs (willtebbutt, Nov 8, 2021)
b2a25fb Add Gorinova reference (willtebbutt, Nov 8, 2021)
9deb019 Add whitening transformation ref (willtebbutt, Nov 8, 2021)
bec6554 Merge in master (willtebbutt, Nov 12, 2021)
f0d0ea5 Fix tests (willtebbutt, Nov 12, 2021)
ab1fc8e Use American English :( (willtebbutt, Nov 12, 2021)
80c3c13 Update test/sparse_variational.jl (willtebbutt, Nov 12, 2021)
9c7d440 Typos (willtebbutt, Nov 12, 2021)
9268786 Apply Theo's suggestions (willtebbutt, Nov 13, 2021)
1c3ea5b Update src/sparse_variational.jl (willtebbutt, Nov 14, 2021)
65e382d Fix for docs (theogf, Nov 15, 2021)
4ae4534 Fix rest of the docs (theogf, Nov 15, 2021)
13bcc08 Apply Ti's formatting suggestions (willtebbutt, Nov 16, 2021)
c1babe5 Add Paciorek reference (willtebbutt, Nov 16, 2021)
237a56e Bump patch (willtebbutt, Nov 16, 2021)
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
32 changes: 28 additions & 4 deletions docs/src/userguide.md
@@ -28,22 +28,46 @@ To construct a sparse approximation to the exact posterior, we first need to sel
M = 15 # The number of inducing points
z = x[1:M]
```
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The stochastic variational Gaussian process (SVGP) approximation is defined by a variational distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The `SparseVariationalApproximation` specifies a distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
```julia
using Distributions, LinearAlgebra
q = MvNormal(zeros(length(z)), I)
```
Finally, we pass our `q` along with the inputs `f(z)` to obtain an approximate posterior GP:
```julia
fz = f(z, 1e-6) # 'observe' the process at z with some jitter for numerical stability
approx = SVGP(fz, q) # Instantiate everything needed for the svgp approximation
approx = SparseVariationalApproximation(fz, q) # Instantiate everything needed for the approximation

svgp_posterior = posterior(approx) # Create the approximate posterior
sva_posterior = posterior(approx) # Create the approximate posterior
```

## The Evidence Lower Bound (ELBO)
The approximate posterior constructed above will be a very poor approximation, since `q` was simply chosen to have zero mean and covariance `I`. A measure of the quality of the approximation is given by the ELBO. Optimising this term with respect to the parameters of `q` and the inducing input locations `z` will improve the approximation.
```julia
elbo(SVGP(fz, q), fx, y)
elbo(SparseVariationalApproximation(fz, q), fx, y)
```
A detailed example of how to carry out such optimisation is given in [Regression: Sparse Variational Gaussian Process for Stochastic Optimisation with Flux.jl](@ref). For an example of non-conjugate inference, see [Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS](@ref).
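
For a flavour of what this involves, the following sketch tunes just the mean of `q` by plain gradient descent, reusing `f`, `z`, `fx`, and `y` from above. It assumes Zygote.jl for the gradient, and `optimise_mean` is an illustrative helper, not part of the package API; the linked examples show the full approach.
```julia
using Distributions, LinearAlgebra, Zygote

function optimise_mean(m, fz, fx, y; iters=100, stepsize=1e-2)
    for _ in 1:iters
        # Gradient of the negative ELBO with respect to the variational mean.
        g = only(Zygote.gradient(m) do m_
            -elbo(SparseVariationalApproximation(fz, MvNormal(m_, I)), fx, y)
        end)
        m = m - stepsize * g  # plain gradient-descent step
    end
    return m
end

m_opt = optimise_mean(zeros(length(z)), f(z, 1e-6), fx, y)
```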

# Available Parametrisations

Two parametrisations of `q(u)` are presently available: centred and non-centred.
The centred parametrisation expresses `q(u)` directly in terms of its mean and covariance.
The non-centred parametrisation instead parametrises the mean and covariance of
`ε := cholesky(cov(u)).U' \ (u - mean(u))`.
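To make this concrete, here is a minimal sketch of the transformation (reusing `fz` from above; purely illustrative):
```julia
using LinearAlgebra

C = cholesky(Symmetric(cov(fz)))  # Cholesky factor of the prior covariance at z
u = rand(fz)                      # a draw from the prior at the inducing inputs
ε = C.U' \ (u - mean(fz))         # whitened coordinates: ε ~ N(0, I) under the prior
```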

The choice of parametrisation can have a substantial impact on the time it takes for ELBO
optimisation to converge, and which parametrisation is better in a particular situation is
not generally obvious.
That being said, the non-centred parametrisation often performs better, so it is the default,
and it is what is used in all of the examples above.

If you require a particular parametrisation, simply use the 3-argument version of the
approximation constructor:
```julia
SparseVariationalApproximation(Centred(), fz, q)
SparseVariationalApproximation(NonCentred(), fz, q)
```
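
The two-argument constructor used earlier in this guide is shorthand for the non-centred form, so the following two objects are equivalent:
```julia
approx_noncentred = SparseVariationalApproximation(NonCentred(), fz, q)
approx_default = SparseVariationalApproximation(fz, q)  # same thing
```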

For a discussion of these two parametrisations, see e.g. [^Gorinova].

[^Gorinova]: Gorinova, Maria, Dave Moore, and Matthew Hoffman. [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a). ICML 2020.
15 changes: 12 additions & 3 deletions src/ApproximateGPs.jl
@@ -13,9 +13,18 @@ using ChainRulesCore
using FillArrays
using KLDivergences

using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A

export SparseVariationalApproximation
using AbstractGPs:
AbstractGP,
FiniteGP,
LatentFiniteGP,
ApproxPosteriorGP,
At_A,
diag_At_A,
Xt_A_X,
Xt_A_Y,
diag_Xt_A_X

export SparseVariationalApproximation, Centred, NonCentred
export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

include("utils.jl")
27 changes: 22 additions & 5 deletions src/elbo.jl
@@ -1,5 +1,11 @@
"""
elbo(svgp::SparseVariationalApproximation, fx::FiniteGP, y::AbstractVector{<:Real}; num_data=length(y), quadrature=DefaultQuadrature())
elbo(
sva::SparseVariationalApproximation,
fx::FiniteGP,
y::AbstractVector{<:Real};
num_data=length(y),
quadrature=DefaultQuadrature(),
)

Compute the Evidence Lower BOund from [1] for the process `f = fx.f ==
sva.fz.f` where `y` are observations of `fx`, pseudo-inputs are given by `z =
@@ -39,7 +45,13 @@ function AbstractGPs.elbo(
end

"""
elbo(svgp, ::SparseVariationalApproximation, lfx::LatentFiniteGP, y::AbstractVector; num_data=length(y), quadrature=DefaultQuadrature())
elbo(
sva::SparseVariationalApproximation,
lfx::LatentFiniteGP,
y::AbstractVector;
num_data=length(y),
quadrature=DefaultQuadrature(),
)

Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood.
"""
@@ -68,9 +80,14 @@ function _elbo(
q_f = marginals(post(fx.x))
variational_exp = expected_loglik(quadrature, y, q_f, lik)

kl_term = KL(sva.q, sva.fz)

n_batch = length(y)
scale = num_data / n_batch
return sum(variational_exp) * scale - kl_term
return sum(variational_exp) * scale - kl_term(sva, post)
end

kl_term(sva::SparseVariationalApproximation{Centred}, post) = KL(sva.q, sva.fz)

function kl_term(sva::SparseVariationalApproximation{NonCentred}, post)
m_ε = mean(sva.q)
return (tr(cov(sva.q)) + m_ε'm_ε - length(m_ε) - logdet(post.data.C_ε)) / 2
end
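
The `NonCentred` method above is the closed-form KL divergence from `q(ε)` to the standard normal, `KL(q(ε) || N(0, I)) = (tr(Σ) + m'm - k - logdet(Σ)) / 2`. A quick standalone sanity check of that identity (a sketch assuming KLDivergences.jl's analytic `KL` for multivariate normals, the same function this PR uses in the `Centred` case):
```julia
using Distributions, KLDivergences, LinearAlgebra

k = 4
m_ε = randn(k)
A = randn(k, k)
Σ_ε = Matrix(Symmetric(A * A' + I))  # a random positive-definite covariance
q = MvNormal(m_ε, Σ_ε)

# Closed form, as in kl_term above (logdet of the Cholesky equals logdet(Σ_ε)).
kl_closed_form = (tr(Σ_ε) + m_ε' * m_ε - k - logdet(cholesky(Symmetric(Σ_ε)))) / 2

kl_closed_form ≈ KL(q, MvNormal(zeros(k), I))  # true, up to floating-point error
```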
174 changes: 160 additions & 14 deletions src/sparse_variational.jl
@@ -1,16 +1,72 @@
raw"""
Centred()

Used in conjunction with `SparseVariationalApproximation`.
States that the `q` field of [`SparseVariationalApproximation`](@ref) is to be interpreted
directly as the approximate posterior over the pseudo-points.

This is also known as the "unwhitened" parametrisation [1].

See also [`NonCentred`](@ref).

[1] - https://en.wikipedia.org/wiki/Whitening_transformation
"""
SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal)
struct Centred end

raw"""
NonCentred()

Used in conjunction with `SparseVariationalApproximation`.
States that the `q` field of [`SparseVariationalApproximation`](@ref) is to be interpreted
as the approximate posterior over `cholesky(cov(u)).L \ (u - mean(u))`, where `u` are the
pseudo-points.

This is also known as the "whitened" parametrisation [1].

Packages the prior over the pseudo-points, `fz`, and the approximate posterior at the
pseudo-points, `q`, together into a single object.
See also [`Centred`](@ref).

[1] - https://en.wikipedia.org/wiki/Whitening_transformation
"""
struct SparseVariationalApproximation{Tfz<:FiniteGP,Tq<:AbstractMvNormal}
struct NonCentred end

struct SparseVariationalApproximation{Parametrisation,Tfz<:FiniteGP,Tq<:AbstractMvNormal}
fz::Tfz
q::Tq
end

raw"""
posterior(sva::SparseVariationalApproximation)
SparseVariationalApproximation(::Parametrisation, fz::FiniteGP, q::AbstractMvNormal)

Produce a `SparseVariationalApproximation{Parametrisation}`, which packages the prior over
the pseudo-points, `fz`, and the approximate posterior at the pseudo-points, `q`, together
into a single object.

The `Parametrisation` determines the precise manner in which `q` and `fz` are interpreted.
Existing parametrisations include [`Centred`](@ref) and [`NonCentred`](@ref).
"""
function SparseVariationalApproximation(
::Parametrisation, fz::Tfz, q::Tq
) where {Parametrisation,Tfz<:FiniteGP,Tq<:AbstractMvNormal}
return SparseVariationalApproximation{Parametrisation,Tfz,Tq}(fz, q)
end

"""
SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal)

Packages the prior over the pseudo-points `fz`, and the approximate posterior at the
pseudo-points, which is `mean(fz) + cholesky(cov(fz)).U' * ε`, `ε ∼ q`.

Shorthand for
```julia
SparseVariationalApproximation(NonCentred(), fz, q)
```
"""
function SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal)
return SparseVariationalApproximation(NonCentred(), fz, q)
end

raw"""
posterior(sva::SparseVariationalApproximation{Centred})

Compute the approximate posterior [1] over the process `f =
sva.fz.f`, given inducing inputs `z = sva.fz.x` and a variational
@@ -27,7 +83,7 @@ which can be found in closed form.
variational Gaussian process classification." Artificial Intelligence and
Statistics. PMLR, 2015.
"""
function AbstractGPs.posterior(sva::SparseVariationalApproximation)
function AbstractGPs.posterior(sva::SparseVariationalApproximation{Centred})
q, fz = sva.q, sva.fz
m, S = mean(q), _chol_cov(q)
Kuu = _chol_cov(fz)
@@ -38,41 +94,41 @@ function AbstractGPs.posterior(sva::SparseVariationalApproximation)
end

function AbstractGPs.posterior(
sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector
sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real}
)
@assert sva.fz.f === fx.f
return posterior(sva)
end

#
# Code below this point just implements the Internal AbstractGPs API.
# Various methods implementing the Internal AbstractGPs API.
# See AbstractGPs.jl API docs for more info.
#

function Statistics.mean(
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}}, x::AbstractVector
)
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α
end

function Statistics.cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}}, x::AbstractVector
)
Cux = cov(f.prior, inducing_points(f), x)
D = f.data.Kuu.L \ Cux
return cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D)
end

function Statistics.var(
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}}, x::AbstractVector
)
Cux = cov(f.prior, inducing_points(f), x)
D = f.data.Kuu.L \ Cux
return var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D)
end

function Statistics.cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation},
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}},
x::AbstractVector,
y::AbstractVector,
)
@@ -85,7 +141,7 @@ function Statistics.cov(
end

function StatsBase.mean_and_cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}}, x::AbstractVector
)
Cux = cov(f.prior, inducing_points(f), x)
D = f.data.Kuu.L \ Cux
@@ -95,7 +151,7 @@ function StatsBase.mean_and_var(
end

function StatsBase.mean_and_var(
f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector
f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centred}}, x::AbstractVector
)
Cux = cov(f.prior, inducing_points(f), x)
D = f.data.Kuu.L \ Cux
@@ -104,6 +160,96 @@ function StatsBase.mean_and_var(
return μ, Σ_diag
end

#
# NonCentred parametrisation.
#

raw"""
posterior(sva::SparseVariationalApproximation{NonCentred})

Compute the approximate posterior [1] over the process `f =
sva.fz.f`, given inducing inputs `z = sva.fz.x` and a variational
distribution over inducing points `sva.q` (which represents ``q(ε)``
where `ε = cholesky(cov(fz)).U' \ (f(z) - mean(f(z)))`). The approximate posterior at test
points ``x^*`` where ``f^* = f(x^*)`` is then given by:

```math
q(f^*) = \int p(f^* | ε) q(ε) dε
```
which can be found in closed form.

[1] - Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable
variational Gaussian process classification." Artificial Intelligence and
Statistics. PMLR, 2015.
"""
function AbstractGPs.posterior(approx::SparseVariationalApproximation{NonCentred})
fz = approx.fz
data = (Cuu=_chol_cov(fz), C_ε=_chol_cov(approx.q))
return ApproxPosteriorGP(approx, fz.f, data)
end
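
Substituting `u = mean(u) + Cuu.U' * ε` into the standard SVGP predictive equations gives the moments the methods below compute. Writing `A = Cuu.U' \ K(z, x*)` (the `_A` helper defined below), a sketch of the result is:
```math
q(f^*) = \mathcal{N}\left(m(x^*) + A^\top m_\varepsilon,\; K(x^*, x^*) - A^\top A + A^\top \Sigma_\varepsilon A\right)
```
where `m_ε = mean(sva.q)` and `Σ_ε = cov(sva.q)`.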

#
# Various methods implementing the Internal AbstractGPs API.
# See AbstractGPs.jl API docs for more info.
#

# Produces a matrix that is consistently referred to as A in this file. A more descriptive
# name is, unfortunately, not obvious. It's just an intermediate quantity that happens to
# get used a lot.
_A(f, x) = f.data.Cuu.U' \ cov(f.prior, inducing_points(f), x)

function Statistics.mean(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}}, x::AbstractVector
)
return mean(f.prior, x) + _A(f, x)' * mean(f.approx.q)
end

function Statistics.cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}}, x::AbstractVector
)
A = _A(f, x)
return cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A)
end

function Statistics.var(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}}, x::AbstractVector
)
A = _A(f, x)
return var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A)
end

function Statistics.cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}},
x::AbstractVector,
y::AbstractVector,
)
Ax = _A(f, x)
Ay = _A(f, y)
return cov(f.prior, x, y) - Ax'Ay + Xt_A_Y(Ax, f.data.C_ε, Ay)
end

function StatsBase.mean_and_cov(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}}, x::AbstractVector
)
A = _A(f, x)
μ = mean(f.prior, x) + A' * mean(f.approx.q)
Σ = cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A)
return μ, Σ
end

function StatsBase.mean_and_var(
f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentred}}, x::AbstractVector
)
A = _A(f, x)
μ = mean(f.prior, x) + A' * mean(f.approx.q)
Σ = var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A)
return μ, Σ
end

#
# Misc utility.
#

inducing_points(f::ApproxPosteriorGP{<:SparseVariationalApproximation}) = f.approx.fz.x

_chol_cov(q::AbstractMvNormal) = cholesky(Symmetric(cov(q)))
4 changes: 2 additions & 2 deletions test/elbo.jl
@@ -8,9 +8,9 @@
f = GP(kernel)
fx = f(x, 0.1)
fz = f(z)
q_ex = exact_variational_posterior(fz, fx, y)
q_ex = optimal_variational_posterior(fz, fx, y)

sva = SparseVariationalApproximation(fz, q_ex)
sva = SparseVariationalApproximation(Centred(), fz, q_ex)
@test elbo(sva, fx, y) isa Real
@test elbo(sva, fx, y) ≤ logpdf(fx, y)
