Implement whitened parametrisation #71

Merged
merged 33 commits into from
Nov 16, 2021
Changes from 32 commits
Commits
4d4c182
Implement whitened parametrisation
willtebbutt Nov 5, 2021
ec8d49d
Bump patch
willtebbutt Nov 5, 2021
b0e69d2
Improve docs
willtebbutt Nov 5, 2021
a11f40e
Apply suggestions from code review
willtebbutt Nov 5, 2021
145d7f9
Update docs
willtebbutt Nov 5, 2021
be44580
SVGP -> SparseVariationalApproximation
willtebbutt Nov 5, 2021
523e8ec
Merge branch 'wct/whitened-inference' of https://github.com/JuliaGaus…
willtebbutt Nov 5, 2021
065041e
Fix docstring typo
willtebbutt Nov 5, 2021
093776e
Update docs
willtebbutt Nov 8, 2021
f6adf96
Refactor to use type parameter
willtebbutt Nov 8, 2021
8cef7ba
Test kl_term
willtebbutt Nov 8, 2021
34c5e7c
Merge in changes
willtebbutt Nov 8, 2021
f1abc2e
Run all tests
willtebbutt Nov 8, 2021
c688025
Clarify tests
willtebbutt Nov 8, 2021
cbdae25
Apply suggestions from code review
willtebbutt Nov 8, 2021
7bfa962
Stabilise numerics in sparse_variational tests
willtebbutt Nov 8, 2021
e70a176
Merge in master
willtebbutt Nov 8, 2021
3965248
Stabilise tests
willtebbutt Nov 8, 2021
46fb9ec
Improve docs
willtebbutt Nov 8, 2021
b2a25fb
Add Gorinova reference
willtebbutt Nov 8, 2021
9deb019
Add whitening transformation ref
willtebbutt Nov 8, 2021
bec6554
Merge in master
willtebbutt Nov 12, 2021
f0d0ea5
Fix tests
willtebbutt Nov 12, 2021
ab1fc8e
Use American English :(
willtebbutt Nov 12, 2021
80c3c13
Update test/sparse_variational.jl
willtebbutt Nov 12, 2021
9c7d440
Typos
willtebbutt Nov 12, 2021
9268786
Apply Theo's suggestions
willtebbutt Nov 13, 2021
1c3ea5b
Update src/sparse_variational.jl
willtebbutt Nov 14, 2021
65e382d
Fix for docs
theogf Nov 15, 2021
4ae4534
Fix rest of the docs
theogf Nov 15, 2021
13bcc08
Apply Ti's formatting suggestions
willtebbutt Nov 16, 2021
c1babe5
Add Paciorek reference
willtebbutt Nov 16, 2021
237a56e
Bump patch
willtebbutt Nov 16, 2021
48 changes: 30 additions & 18 deletions docs/Manifest.toml
@@ -15,7 +15,7 @@ version = "0.5.3"
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
path = ".."
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
version = "0.2.0"
version = "0.2.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -32,6 +32,12 @@ git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.1"

[[ChangesOfVariables]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "9a1d594397670492219635b35a3d830b04730d62"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.1"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
@@ -72,6 +78,12 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DensityInterface]]
deps = ["InverseFunctions", "Test"]
git-tree-sha1 = "794daf62dce7df839b8ed446fc59c68db4b5182f"
uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
version = "0.3.3"

[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
@@ -95,10 +107,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["ChainRulesCore", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "72dcda9e19f88d09bf21b5f9507a0bb430bce2aa"
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "cce8159f0fee1281335a04bbf876572e46c921ba"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.24"
version = "0.25.29"

[[DocStringExtensions]]
deps = ["LibGit2"]
@@ -129,21 +141,21 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.7"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "ef3fec65f9db26fa2cf8f4133c697c5b7ce63c1d"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "6406b5112809c08b1baa5703ad274e1dded0652f"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.22"
version = "0.10.23"

[[Functors]]
git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.7"

[[GPLikelihoods]]
deps = ["Distributions", "Functors", "LinearAlgebra", "Random", "StatsFuns"]
git-tree-sha1 = "bdfe8a65b3ca3aa92812d74138264570f33aa66e"
deps = ["Distributions", "Functors", "InverseFunctions", "LinearAlgebra", "Random", "StatsFuns"]
git-tree-sha1 = "561e03fc0dc1d38560dc1403ad95b308418f0ed6"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
version = "0.2.4"
version = "0.2.5"

[[IOCapture]]
deps = ["Logging", "Random"]
@@ -157,9 +169,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308"
git-tree-sha1 = "a7254c0acd8e62f1ac75ad24d5db43f5f19f3c65"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.1"
version = "0.1.2"

[[IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
@@ -214,10 +226,10 @@ deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214"
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "be9eef9f9d78cecb6f262f3c10da151a6c5ab827"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.4"
version = "0.3.5"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -390,10 +402,10 @@ uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"

[[StatsFuns]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "95072ef1a22b057b1e80f73c2a89ad238ae4cfff"
deps = ["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "385ab64e64e79f0cd7cfcf897169b91ebbb2d6c8"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.12"
version = "0.9.13"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
35 changes: 31 additions & 4 deletions docs/src/userguide.md
@@ -28,22 +28,49 @@ To construct a sparse approximation to the exact posterior, we first need to sel
M = 15 # The number of inducing points
z = x[1:M]
```
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The stochastic variational Gaussian process (SVGP) approximation is defined by a variational distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The [`SparseVariationalApproximation`](@ref) specifies a distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
```julia
using Distributions, LinearAlgebra
q = MvNormal(zeros(length(z)), I)
```
Finally, we pass `q`, along with the GP evaluated at the inducing inputs, `f(z)`, to obtain an approximate posterior GP:
```julia
fz = f(z, 1e-6) # 'observe' the process at z with some jitter for numerical stability
approx = SVGP(fz, q) # Instantiate everything needed for the svgp approximation
approx = SparseVariationalApproximation(fz, q) # Instantiate everything needed for the approximation

svgp_posterior = posterior(approx) # Create the approximate posterior
sva_posterior = posterior(approx) # Create the approximate posterior
```

## The Evidence Lower Bound (ELBO)
The approximate posterior constructed above will be a very poor approximation, since `q` was simply chosen to have zero mean and covariance `I`. A measure of the quality of the approximation is given by the ELBO. Optimising this term with respect to the parameters of `q` and the inducing input locations `z` will improve the approximation.
```julia
elbo(SVGP(fz, q), fx, y)
elbo(SparseVariationalApproximation(fz, q), fx, y)
```
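
For reference, the ELBO for this family of approximations is commonly written as below. This is a sketch in standard notation rather than something taken from this package's source: `p(u)` is assumed to denote the GP prior at the inducing inputs `z`, and `q(f(x))` the marginal of the approximate posterior at the training inputs `x`.

```math
\mathrm{ELBO}(q) = \mathbb{E}_{q(f(x))}\left[\log p(y \mid f(x))\right] - \mathrm{KL}\left[q(u) \,\|\, p(u)\right],
\qquad q(f(x)) = \int p(f(x) \mid u)\, q(u)\, \mathrm{d}u
```

The first term rewards fitting the observations `y`, while the second penalizes `q(u)` for moving away from the prior over the pseudo-points.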
A detailed example of how to carry out such optimisation is given in [Regression: Sparse Variational Gaussian Process for Stochastic Optimisation with Flux.jl](@ref). For an example of non-conjugate inference, see [Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS](@ref).

# Available Parametrizations

Two parametrizations of `q(u)` are presently available: [`Centered`](@ref) and [`NonCentered`](@ref).
The `Centered` parametrization expresses `q(u)` directly in terms of its mean and covariance.
The `NonCentered` parametrization instead parametrizes the mean and covariance of
`ε := cholesky(cov(u)).U' \ (u - mean(u))`, where `mean(u)` and `cov(u)` are taken under the GP prior at `z`.
These parametrizations are also known as "Unwhitened" and "Whitened", respectively.
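
The relationship between `u` and `ε` can be checked numerically with a few lines of plain `LinearAlgebra`. The following is only an illustrative sketch: the prior mean `m_p`, prior covariance `K_p`, and the helper names `whiten` / `unwhiten` are made up for this example and are not part of the package's API.

```julia
using LinearAlgebra

# Hypothetical prior over the pseudo-points u = f(z): mean m_p and covariance K_p.
m_p = [0.5, -1.0, 2.0]
A = [1.0 0.2 0.0; 0.3 1.5 0.1; 0.0 0.4 2.0]
K_p = Symmetric(A * A' + 1e-6 * I)  # symmetric positive definite by construction

U = cholesky(K_p).U  # upper-triangular factor, K_p ≈ U' * U

# Whitening map from the text, and its inverse (illustrative helper names).
whiten(u) = U' \ (u - m_p)
unwhiten(ε) = m_p + U' * ε

# Round trip: whitening followed by unwhitening recovers u.
u = [1.0, 0.0, -1.0]
ε = whiten(u)
@assert unwhiten(ε) ≈ u

# A Gaussian q(ε) = N(m_ε, S_ε) implies a Gaussian q(u) with these moments.
m_ε = [0.1, 0.2, 0.3]
S_ε = Matrix(Diagonal([0.5, 1.0, 2.0]))
m_u = unwhiten(m_ε)   # mean of q(u)
S_u = U' * S_ε * U    # covariance of q(u)
```

One consequence worth noting: setting `m_ε` to zero and `S_ε` to the identity makes `q(u)` coincide with the prior, which is why a standard-normal `q` is a natural starting point under the whitened parametrization.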

The choice of parametrization can have a substantial impact on the time it takes for ELBO
optimization to converge, and which parametrization is better in a particular situation is
not generally obvious.
That being said, the `NonCentered` parametrization often converges in fewer iterations, so it is the default;
it is what is used in all of the examples above.

If you require a particular parametrization, simply use the 3-argument version of the
approximation constructor:
```julia
SparseVariationalApproximation(Centered(), fz, q)
SparseVariationalApproximation(NonCentered(), fz, q)
```

For a general discussion around these two parametrizations, see e.g. [^Gorinova].
For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].

[^Gorinova]: [Gorinova, Maria, Dave Moore, and Matthew Hoffman. Automatic Reparameterisation of Probabilistic Programs.](http://proceedings.mlr.press/v119/gorinova20a)
[^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf)
15 changes: 12 additions & 3 deletions src/ApproximateGPs.jl
@@ -13,9 +13,18 @@ using ChainRulesCore
using FillArrays
using KLDivergences

using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A

export SparseVariationalApproximation
using AbstractGPs:
AbstractGP,
FiniteGP,
LatentFiniteGP,
ApproxPosteriorGP,
At_A,
diag_At_A,
Xt_A_X,
Xt_A_Y,
diag_Xt_A_X

export SparseVariationalApproximation, Centered, NonCentered
export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

include("utils.jl")