move expected_loglik here from ApproximateGPs.jl (#70)

Moves `expected_loglik` from ApproximateGPs.jl to GPLikelihoods.jl.

Changes from the original code in ApproximateGPs.jl:
- rename to `expected_loglikelihood` to be more explicit
- change the signature: the argument order is now `quadrature`, `lik`, `q_f`, `y` (see the sketch below)
- rename `_default_quadrature` to `default_expectation_method`
- rename the expectation structs to `DefaultExpectationMethod`, `AnalyticExpectation`, `GaussHermiteExpectation`, `MonteCarloExpectation`; they are not exported from here
- drop the abstract base type (it was not used for anything)
- drop the default argument for the number of Gauss-Hermite / Monte Carlo quadrature points (the default now lives in `DefaultExpectationMethod`)
- store the Gauss-Hermite nodes and weights inside the `GaussHermiteExpectation` struct; resolves #71
- move the analytic definitions into the file of each likelihood implementation
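
As a quick illustration of the new call (a sketch only - the likelihood, toy marginals, and observations below are made up for this example):

```julia
using Distributions: Normal
using GPLikelihoods

lik = PoissonLikelihood()             # p(y | f) = Poisson(exp(f))
q_f = Normal.(randn(5), ones(5))      # marginals of q(f)
y = rand.(lik.(randn(5)))             # toy observations

# New argument order: quadrature method, likelihood, marginals, observations.
GPLikelihoods.expected_loglikelihood(
    GPLikelihoods.DefaultExpectationMethod(), lik, q_f, y
)
```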

Note: bumps julia compat to 1.6.

Co-authored-by: willtebbutt <wct23@cam.ac.uk>
st-- and willtebbutt authored Mar 28, 2022
1 parent 932f516 commit 18ae436
Showing 11 changed files with 286 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -22,7 +22,7 @@ jobs:
      matrix:
        version:
          - '1'
          - '1.3'
          - '1.6'
          - 'nightly'
        os:
          - ubuntu-latest
12 changes: 10 additions & 2 deletions Project.toml
@@ -1,20 +1,28 @@
name = "GPLikelihoods"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.1"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1.7"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FastGaussQuadrature = "0.4"
Functors = "0.1, 0.2"
InverseFunctions = "0.1.2"
IrrationalConstants = "0.1"
SpecialFunctions = "1, 2"
StatsFuns = "0.9.13"
julia = "1.3"
julia = "1.6"
2 changes: 2 additions & 0 deletions src/GPLikelihoods.jl
@@ -34,6 +34,8 @@ include("links.jl")

# Likelihoods
abstract type AbstractLikelihood end

include("expectations.jl")
include("likelihoods/bernoulli.jl")
include("likelihoods/categorical.jl")
include("likelihoods/gaussian.jl")
117 changes: 117 additions & 0 deletions src/expectations.jl
@@ -0,0 +1,117 @@
using FastGaussQuadrature: gausshermite
using SpecialFunctions: loggamma
using ChainRulesCore: ChainRulesCore
using IrrationalConstants: sqrt2, invsqrtπ

struct DefaultExpectationMethod end

struct AnalyticExpectation end

struct GaussHermiteExpectation
    xs::Vector{Float64}
    ws::Vector{Float64}
end
GaussHermiteExpectation(n::Integer) = GaussHermiteExpectation(gausshermite(n)...)

ChainRulesCore.@non_differentiable gausshermite(n)

struct MonteCarloExpectation
    n_samples::Int
end

default_expectation_method(_) = GaussHermiteExpectation(20)

"""
expected_loglikelihood(
quadrature,
lik,
q_f::AbstractVector{<:Normal},
y::AbstractVector,
)
This function computes the expected log likelihood:
```math
∫ q(f) log p(y | f) df
```
where `p(y | f)` is the process likelihood. This is described by `lik`, which should be a
callable that takes `f` as input and returns a Distribution over `y` that supports
`loglikelihood(lik(f), y)`.
`q(f)` is an approximation to the latent function values `f` given by:
```math
q(f) = ∫ p(f | u) q(u) du
```
where `q(u)` is the variational distribution over inducing points (see
[`elbo`](@ref)). The marginal distributions of `q(f)` are given by `q_f`.
`quadrature` determines which method is used to calculate the expected log
likelihood - see [`elbo`](@ref) for more details.
# Extended help
`q(f)` is assumed to be an `MvNormal` distribution and `p(y | f)` is assumed to
have independent marginals such that only the marginals of `q(f)` are required.
"""
expected_loglikelihood(quadrature, lik, q_f, y)

"""
expected_loglikelihood(::DefaultExpectationMethod, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector)
The expected log likelihood, using the default quadrature method for the given likelihood.
(The default quadrature method is defined by `default_expectation_method(lik)`, and should
be the closed form solution if it exists, but otherwise defaults to Gauss-Hermite
quadrature.)
"""
function expected_loglikelihood(
::DefaultExpectationMethod, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
quadrature = default_expectation_method(lik)
return expected_loglikelihood(quadrature, lik, q_f, y)
end

function expected_loglikelihood(
    mc::MonteCarloExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    # take `n_samples` reparameterised samples
    f_μ = mean.(q_f)
    fs = f_μ .+ std.(q_f) .* randn(eltype(f_μ), length(q_f), mc.n_samples)
    lls = loglikelihood.(lik.(fs), y)
    return sum(lls) / mc.n_samples
end

# Compute the expected_loglikelihood over a collection of observations and marginal distributions
function expected_loglikelihood(
    gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    # Compute the expectation via Gauss-Hermite quadrature
    # using a reparameterisation by change of variable
    # (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)
    return sum(Broadcast.instantiate(
        Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ  # Loop over every pair
            # of marginal distribution q(fᵢ) and observation yᵢ
            expected_loglikelihood(gh, lik, q_fᵢ, yᵢ)
        end,
    ))
end

# Compute the expected_loglikelihood for one observation and one marginal distribution
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)
    μ = mean(q_f)
    σ̃ = sqrt2 * std(q_f)
    return invsqrtπ * sum(Broadcast.instantiate(
        Broadcast.broadcasted(gh.xs, gh.ws) do x, w  # Loop over every
            # pair of Gauss-Hermite point x with weight w
            f = σ̃ * x + μ
            loglikelihood(lik(f), y) * w
        end,
    ))
end
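
# Note on the quadrature above (the standard Gauss-Hermite change of variables):
# with f = μ + √2 σ x and q(f) = N(f; μ, σ²),
#   ∫ N(f; μ, σ²) log p(y | f) df ≈ (1/√π) ∑ᵢ wᵢ log p(y | μ + √2 σ xᵢ),
# where (xᵢ, wᵢ) are the Gauss-Hermite nodes and weights stored in `gh.xs` and `gh.ws`.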

function expected_loglikelihood(
    ::AnalyticExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    return error(
        "No analytic solution exists for $(typeof(lik)). Use `DefaultExpectationMethod`, `GaussHermiteExpectation` or `MonteCarloExpectation` instead.",
    )
end
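
A rough usage sketch of the API defined in this file (the inputs below are toy values for illustration, not from the commit):

```julia
using Distributions: Normal
using GPLikelihoods
using GPLikelihoods: DefaultExpectationMethod, GaussHermiteExpectation, MonteCarloExpectation

lik = GaussianLikelihood(0.1)          # p(y | f) = N(y; f, 0.1)
q_f = Normal.(randn(10), ones(10))     # marginals of q(f)
y = randn(10)                          # toy observations

# GaussianLikelihood has an analytic expectation, so the default method is closed form:
el_default = GPLikelihoods.expected_loglikelihood(DefaultExpectationMethod(), lik, q_f, y)

# The quadrature and sampling approximations should agree with it up to their error:
el_gh = GPLikelihoods.expected_loglikelihood(GaussHermiteExpectation(20), lik, q_f, y)
el_mc = GPLikelihoods.expected_loglikelihood(MonteCarloExpectation(10_000), lik, q_f, y)
```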
12 changes: 12 additions & 0 deletions src/likelihoods/exponential.jl
@@ -16,3 +16,15 @@ ExponentialLikelihood(l=exp) = ExponentialLikelihood(link(l))
(l::ExponentialLikelihood)(f::Real) = Exponential(l.invlink(f))

(l::ExponentialLikelihood)(fs::AbstractVector{<:Real}) = Product(map(l, fs))

function expected_loglikelihood(
    ::AnalyticExpectation,
    ::ExponentialLikelihood{ExpLink},
    q_f::AbstractVector{<:Normal},
    y::AbstractVector{<:Real},
)
    f_μ = mean.(q_f)
    return sum(-f_μ - y .* exp.((var.(q_f) / 2) .- f_μ))
end

default_expectation_method(::ExponentialLikelihood{ExpLink}) = AnalyticExpectation()
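
For reference, the exponential closed form above can be checked with the standard lognormal moment identity (a sketch, not part of the diff): with the exp link, `lik(f)` is `Exponential(exp(f))` in the scale parameterisation, so `log p(y | f) = -f - y exp(-f)`, and for `f ~ Normal(μ, σ²)`,

```math
\mathbb{E}_{q(f)}[\log p(y \mid f)] = -\mu - y\,\mathbb{E}[e^{-f}] = -\mu - y\, e^{\sigma^2/2 - \mu},
```

which is the expression summed over the marginals in the code. The Gamma and Poisson closed forms in the files below rely on the same `E[exp(±f)] = exp(±μ + σ²/2)` identity.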
15 changes: 15 additions & 0 deletions src/likelihoods/gamma.jl
@@ -21,3 +21,18 @@ GammaLikelihood(α::Real=1.0, l=exp) = GammaLikelihood(α, link(l))
(l::GammaLikelihood)(f::Real) = Gamma(l.α, l.invlink(f))

(l::GammaLikelihood)(fs::AbstractVector{<:Real}) = Product(map(l, fs))

function expected_loglikelihood(
    ::AnalyticExpectation,
    lik::GammaLikelihood{ExpLink},
    q_f::AbstractVector{<:Normal},
    y::AbstractVector{<:Real},
)
    f_μ = mean.(q_f)
    return sum(
        (lik.α - 1) * log.(y) .- y .* exp.((var.(q_f) / 2) .- f_μ) .- lik.α * f_μ .-
        loggamma(lik.α),
    )
end

default_expectation_method(::GammaLikelihood{ExpLink}) = AnalyticExpectation()
13 changes: 13 additions & 0 deletions src/likelihoods/gaussian.jl
@@ -23,6 +23,19 @@ GaussianLikelihood(σ²::Real) = GaussianLikelihood([σ²])

(l::GaussianLikelihood)(fs::AbstractVector{<:Real}) = MvNormal(fs, first(l.σ²) * I)

function expected_loglikelihood(
    ::AnalyticExpectation,
    lik::GaussianLikelihood,
    q_f::AbstractVector{<:Normal},
    y::AbstractVector{<:Real},
)
    return sum(
        -0.5 * (log(2π) .+ log.(lik.σ²) .+ ((y .- mean.(q_f)) .^ 2 .+ var.(q_f)) / lik.σ²)
    )
end

default_expectation_method(::GaussianLikelihood) = AnalyticExpectation()

"""
HeteroscedasticGaussianLikelihood(l=exp)
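
For reference, the `GaussianLikelihood` closed form added above follows by expanding the square (a standard result, not part of the diff): for `q(f) = Normal(μ, v)` and noise variance `σ²`, `E[(y - f)²] = (y - μ)² + v`, so

```math
\mathbb{E}_{q(f)}\left[\log \mathcal{N}(y; f, \sigma^2)\right]
  = -\tfrac{1}{2}\left(\log(2\pi\sigma^2) + \frac{(y - \mu)^2 + v}{\sigma^2}\right),
```

which matches the term summed over each marginal/observation pair in the code.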
12 changes: 12 additions & 0 deletions src/likelihoods/poisson.jl
@@ -19,3 +19,15 @@ PoissonLikelihood(l=exp) = PoissonLikelihood(link(l))
(l::PoissonLikelihood)(f::Real) = Poisson(l.invlink(f))

(l::PoissonLikelihood)(fs::AbstractVector{<:Real}) = Product(map(l, fs))

function expected_loglikelihood(
    ::AnalyticExpectation,
    ::PoissonLikelihood{ExpLink},
    q_f::AbstractVector{<:Normal},
    y::AbstractVector{<:Real},
)
    f_μ = mean.(q_f)
    return sum((y .* f_μ) - exp.(f_μ .+ (var.(q_f) / 2)) - loggamma.(y .+ 1))
end

default_expectation_method(::PoissonLikelihood{ExpLink}) = AnalyticExpectation()
1 change: 1 addition & 0 deletions test/Project.toml
@@ -4,6 +4,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
100 changes: 100 additions & 0 deletions test/expectations.jl
@@ -0,0 +1,100 @@
@testset "expectations" begin
# Test that the various methods of computing expectations return the same
# result.
rng = MersenneTwister(123456)
q_f = Normal.(zeros(10), ones(10))

likelihoods_to_test = [
ExponentialLikelihood(),
GammaLikelihood(),
PoissonLikelihood(),
GaussianLikelihood(),
]

@testset "testing all analytic implementations" begin
# Test that we're not missing any analytic implementation in `likelihoods_to_test`!
implementation_types = [
(; quadrature=m.sig.types[2], lik=m.sig.types[3]) for
m in methods(GPLikelihoods.expected_loglikelihood)
]
analytic_likelihoods = [
m.lik for m in implementation_types if
m.quadrature == GPLikelihoods.AnalyticExpectation && m.lik != Any
]
for lik_type in analytic_likelihoods
lik_type_instances = filter(lik -> isa(lik, lik_type), likelihoods_to_test)
@test !isempty(lik_type_instances)
lik = first(lik_type_instances)
@test GPLikelihoods.default_expectation_method(lik) isa
GPLikelihoods.AnalyticExpectation
end
end

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))

results = map(m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end

    @test GPLikelihoods.expected_loglikelihood(
        MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10)
    ) isa Real
    @test GPLikelihoods.expected_loglikelihood(
        GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10)
    ) isa Real
    @test GPLikelihoods.default_expectation_method(θ -> Normal(0, θ)) isa
        GaussHermiteExpectation

    # see https://github.com/JuliaGaussianProcesses/ApproximateGPs.jl/issues/82
    @testset "testing Zygote compatibility with GaussHermiteExpectation" begin
        N = 10
        gh = GaussHermiteExpectation(12)
        μs = randn(rng, N)
        σs = rand(rng, N)

        # Test differentiation with variational parameters
        for lik in likelihoods_to_test
            y = rand.(rng, lik.(rand.(Normal.(μs, σs))))
            gμ, glogσ = Zygote.gradient(μs, log.(σs)) do μ, logσ
                GPLikelihoods.expected_loglikelihood(gh, lik, Normal.(μ, exp.(logσ)), y)
            end
            @test all(isfinite, gμ)
            @test all(isfinite, glogσ)
        end

        # Test differentiation with likelihood parameters
        # Test GaussianLikelihood parameter
        σ = 1.0
        y = randn(rng, N)
        glogσ = only(
            Zygote.gradient(log(σ)) do x
                GPLikelihoods.expected_loglikelihood(
                    gh, GaussianLikelihood(exp(x)), Normal.(μs, σs), y
                )
            end,
        )
        @test isfinite(glogσ)

        # Test GammaLikelihood parameter
        α = 2.0
        y = rand.(rng, Gamma.(α, rand(N)))
        glogα = only(
            Zygote.gradient(log(α)) do x
                GPLikelihoods.expected_loglikelihood(
                    gh, GammaLikelihood(exp(x)), Normal.(μs, σs), y
                )
            end,
        )
        @test isfinite(glogα)
    end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -1,10 +1,12 @@
using GPLikelihoods
using GPLikelihoods: GaussHermiteExpectation, MonteCarloExpectation
using GPLikelihoods.TestInterface: test_interface
using Test
using Random
using Functors
using Distributions
using StatsFuns
using Zygote

@testset "GPLikelihoods.jl" begin
include("links.jl")
@@ -17,4 +19,5 @@ using StatsFuns
        include("likelihoods/exponential.jl")
        include("likelihoods/negativebinomial.jl")
    end
    include("expectations.jl")
end

2 comments on commit 18ae436

@st-- (Member Author) commented on 18ae436, Mar 28, 2022

@JuliaRegistrator commented:

Registration pull request created: JuliaRegistries/General/57473

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed; otherwise it can be done manually through the GitHub interface, or via:

git tag -a v0.4.0 -m "<description of version>" 18ae4362eda2f2218c6fcdc84051dea306505247
git push origin v0.4.0
