Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revival of MCMCTempering.jl #158

Merged
merged 17 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCTempering"
uuid = "ce233488-44ea-4441-b732-192676ce2298"
authors = ["Harrison Wilde <h.wilde@warwick.ac.uk> and contributors"]
version = "0.3.2"
version = "0.4.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -17,7 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[compat]
AbstractMCMC = "3.2, 4"
AbstractMCMC = "5"
ConcreteStructs = "0.2"
Distributions = "0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ makedocs(
sitename = "MCMCTempering",
format = Documenter.HTML(),
modules = [MCMCTempering],
pages=["Home" => "index.md", "getting-started.md", "api.md"],
pages = ["Home" => "index.md", "getting-started.md", "api.md"],
warnonly = true
)

# Deply!
Expand Down
11 changes: 11 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ MCMCTempering.swap_step

## Other samplers

To make a sampler work with MCMCTempering.jl, the sampler needs to implement a few methods:

```@docs
MCMCTempering.getparams
MCMCTempering.getlogprob
MCMCTempering.getparams_and_logprob
MCMCTempering.setparams_and_logprob!!
```

Other useful methods are:

```@docs
MCMCTempering.saveall
```
Expand Down
16 changes: 8 additions & 8 deletions docs/src/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ Let's instead try to use a _tempered_ version of `RWMH`. _But_ before we do that

To do that we need to implement two methods. First we need to tell MCMCTempering how to extract the parameters, and potentially the log-probabilities, from a `AdvancedMH.Transition`:

```@docs
```@docs; canonical=false
MCMCTempering.getparams_and_logprob
```

And similarly, we need a way to _update_ the parameters and the log-probabilities of a `AdvancedMH.Transition`:

```@docs
```@docs; canonical=false
MCMCTempering.setparams_and_logprob!!
```

Expand Down Expand Up @@ -159,10 +159,7 @@ using AdvancedHMC: AdvancedHMC
using ForwardDiff: ForwardDiff # for automatic differentation of the logdensity

# Creation of the sampler.
metric = AdvancedHMC.DiagEuclideanMetric(1)
integrator = AdvancedHMC.Leapfrog(0.1)
proposal = AdvancedHMC.StaticTrajectory(integrator, 8)
sampler = AdvancedHMC.HMCSampler(proposal, metric)
sampler = AdvancedHMC.HMC(0.1, 8)
sampler_tempered = MCMCTempering.TemperedSampler(sampler, inverse_temperatures)

# Sample!
Expand All @@ -172,6 +169,7 @@ chain = sample(
target_model, sampler, num_iterations;
chain_type=MCMCChains.Chains,
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
)
plot(chain, size=figsize)
```
Expand Down Expand Up @@ -205,7 +203,8 @@ chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
```

Expand Down Expand Up @@ -289,7 +288,8 @@ chain_tempered_all = sample(
StableRNG(42),
target_model, sampler_tempered, num_iterations;
chain_type=Vector{MCMCChains.Chains},
param_names=["x"]
param_names=["x"],
n_adapts=0, # HACK: need this to make AdvancedHMC.jl happy :/
);
```

Expand Down
10 changes: 5 additions & 5 deletions src/samplers/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::MultiModel,
sampler::MultiSampler;
init_params=nothing,
initial_params=nothing,
kwargs...
)
@assert length(model.models) == length(sampler.samplers) "Number of models and samplers must be equal"

# TODO: Handle `init_params` properly. Make sure that they respect the container-types used in
# TODO: Handle `initial_params` properly. Make sure that they respect the container-types used in
# `MultiModel` and `MultiSampler`.
init_params_multi = initparams(model, init_params)
transition_and_states = asyncmap(model.models, sampler.samplers, init_params_multi) do model, sampler, init_params
AbstractMCMC.step(rng, model, sampler; init_params, kwargs...)
initial_params_multi = initparams(model, initial_params)
transition_and_states = asyncmap(model.models, sampler.samplers, initial_params_multi) do model, sampler, initial_params
AbstractMCMC.step(rng, model, sampler; initial_params, kwargs...)
end

return MultipleTransitions(map(first, transition_and_states)), MultipleStates(map(last, transition_and_states))
Expand Down
10 changes: 5 additions & 5 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "3.2, 4"
AdvancedHMC = "0.4"
AdvancedMH = "0.7"
Bijectors = "0.10"
AbstractMCMC = "5"
AdvancedHMC = "0.6"
AdvancedMH = "0.8"
Bijectors = "0.13"
Distributions = "0.24, 0.25"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "6"
Turing = "0.24"
Turing = "0.34"
julia = "1"
9 changes: 6 additions & 3 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@

@testset "SwapSampler" begin
# SwapSampler without tempering (i.e. in a composition and using `MultiModel`, etc.)
init_params = [[5.0], [5.0]]
initial_params = [[5.0], [5.0]]
mdl1 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(4.9999, 1)))
mdl2 = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), DistributionLogDensity(Normal(5.0001, 1)))
spl1 = RWMH(MvNormal(Zeros(dimension(mdl1)), I))
Expand All @@ -181,7 +181,7 @@

# Sample!
rng = Random.default_rng()
transition, state = AbstractMCMC.step(rng, product_model, spl_full; init_params)
transition, state = AbstractMCMC.step(rng, product_model, spl_full; initial_params)
transitions = typeof(transition)[]

# A bit of warm-up.
Expand Down Expand Up @@ -222,7 +222,10 @@

# Check that means are roughly okay.
params_resolved = map(first ∘ MCMCTempering.getparams, transitions_resolved_inner)
@test vec(mean(params_resolved; dims=2)) ≈ [5.0, 5.0] atol = 0.3
mean_tmp = vec(median(params_resolved; dims=2))
for i in 1:2
@test mean_tmp[i] ≈ 5.0 atol = 0.5
end

# A composition of `SwapSampler` and `MultiSampler` has special `AbstractMCMC.bundle_samples`.
@testset "bundle_samples with Vector" begin
Expand Down
6 changes: 5 additions & 1 deletion test/compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ MCMCTempering.getparams_and_logprob(transition::AdvancedMH.GradientTransition) =
function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.GradientTransition, params, lp)
# NOTE: We have to re-compute the gradient here because this will be used in the subsequent `step` for
# the MALA sampler.
return AdvancedMH.GradientTransition(params, AdvancedMH.logdensity_and_gradient(model, params)...)
return AdvancedMH.GradientTransition(
params,
AdvancedMH.logdensity_and_gradient(model, params)...,
transition.accepted
)
end

# AdvancedHMC.jl
Expand Down
90 changes: 63 additions & 27 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Several properties of the tempered sampler are tested before returning:
- `adapt_atol`: The absolute tolerance for the check of average swap acceptance rate and target swap acceptance rate. Defaults to `0.05`.
- `mean_swap_rate_bound`: A bound on the acceptance rate of swaps performed, e.g. if set to `0.1` and `compare_mean_swap_rate=≥` then at least 10% of attempted swaps should be accepted. Defaults to `0.1`.
- `compare_mean_swap_rate`: a binary function for comparing average swap rate against `mean_swap_rate_bound`. Defaults to `≥`.
- `init_params`: The initial parameters to use for the sampler. Defaults to `nothing`.
- `initial_params`: The initial parameters to use for the sampler. Defaults to `nothing`.
- `param_names`: The names of the parameters in the chain; used to construct the resulting chain. Defaults to `missing`.
- `progress`: Whether to show a progress bar. Defaults to `false`.
"""
Expand All @@ -41,10 +41,12 @@ function test_and_sample_model(
adapt_target=0.234,
adapt_rtol=0.1,
adapt_atol=0.05,
init_params=nothing,
initial_params=nothing,
param_names=missing,
progress=false,
minimum_roundtrips=nothing
minimum_roundtrips=nothing,
rng=make_rng(),
kwargs...
)
# Make the tempered sampler.
sampler_tempered = tempered(
Expand All @@ -64,8 +66,9 @@ function test_and_sample_model(

# Sample.
samples_tempered = AbstractMCMC.sample(
model, sampler_tempered, num_iterations;
callback=callback, progress=progress, init_params=init_params
rng, model, sampler_tempered, num_iterations;
callback=callback, progress=progress, initial_params=initial_params,
kwargs...
)

if !isnothing(minimum_roundtrips)
Expand Down Expand Up @@ -245,7 +248,7 @@ end
[1.0, 1e-3], # extreme temperatures -> don't exect much swapping to occur
num_iterations=num_iterations,
adapt=false,
init_params=[[0.0], [1000.0]], # initialized far apart
initial_params=[[0.0], [1000.0]], # initialized far apart
# At MOST 1% of swaps should be successful.
mean_swap_rate_bound=0.01,
compare_mean_swap_rate=≤,
Expand Down Expand Up @@ -302,7 +305,7 @@ end
# Set up our sampler with a joint multivariate Normal proposal.
sampler = RWMH(MvNormal(zeros(d), Diagonal(σ_true .^ 2)))
# Sample for the non-tempered model for comparison.
samples = AbstractMCMC.sample(model, sampler, num_iterations; progress=false)
samples = AbstractMCMC.sample(make_rng(), model, sampler, num_iterations; progress=false)
chain = AbstractMCMC.bundle_samples(
samples, MCMCTempering.maybe_wrap_model(model), sampler, samples[1], MCMCChains.Chains
)
Expand All @@ -326,19 +329,41 @@ end
adapt=false,
# Make sure we have _some_ roundtrips.
minimum_roundtrips=10,
rng=make_rng(),
)

compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true)
# Some swap strategies are not great.
ess_slack_ratio = if swap_strategy isa Union{MCMCTempering.SingleRandomSwap,MCMCTempering.SingleSwap}
0.25
else
0.5
end
compare_chains(chain, chain_tempered, rtol=0.1, compare_ess=true, compare_ess_slack=ess_slack_ratio)
end
end

@testset "Turing.jl" begin
# Let's make a default seed we can `deepcopy` throughout to get reproducible results.
seed = 42

# And let's set the seed explicitly for reproducibility.
Random.seed!(seed)

# Instantiate model.
model_dppl = DynamicPPL.TestUtils.demo_assume_dot_observe()
DynamicPPL.@model function demo_model(x)
s ~ Exponential()
m ~ Normal(0, s)
x .~ Normal(m, s)
end
xs_true = rand(Normal(2, 1), 100)
model_dppl = demo_model(xs_true)

# Move to unconstrained space.
vi = DynamicPPL.VarInfo(model_dppl)
# Move to unconstrained space.
vi = DynamicPPL.link!!(DynamicPPL.VarInfo(model_dppl), model_dppl)
vi = DynamicPPL.link!!(vi, model_dppl)
# Get some initial values in unconstrained space.
init_params = copy(vi[:])
initial_params = rand(length(vi[:]))
# Get the parameter names.
param_names = map(Symbol, DynamicPPL.TestUtils.varnames(model_dppl))
# Get bijector so we can get back to unconstrained space afterwards.
Expand All @@ -349,9 +374,9 @@ end
@testset "Tempering of models" begin
beta = 0.5
model_tempered = MCMCTempering.make_tempered_model(model, beta)
@test logdensity(model_tempered, init_params) ≈ beta * logdensity(model, init_params)
@test last(logdensity_and_gradient(model_tempered, init_params)) ≈
beta .* last(logdensity_and_gradient(model, init_params))
@test logdensity(model_tempered, initial_params) ≈ beta * logdensity(model, initial_params)
@test last(logdensity_and_gradient(model_tempered, initial_params)) ≈
beta .* last(logdensity_and_gradient(model, initial_params))
end

@testset "AdvancedHMC.jl" begin
Expand All @@ -360,12 +385,19 @@ end
# Set up HMC smpler.
initial_ϵ = 0.1
integrator = AdvancedHMC.Leapfrog(initial_ϵ)
proposal = AdvancedHMC.NUTS{AdvancedHMC.MultinomialTS, AdvancedHMC.GeneralisedNoUTurn}(integrator)
proposal = AdvancedHMC.HMCKernel(AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS}(
integrator, AdvancedHMC.GeneralisedNoUTurn()
))
metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(model))
sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric)
sampler_hmc = AdvancedHMC.HMCSampler(proposal, metric, AdvancedHMC.NoAdaptation())

# Sample using HMC.
samples_hmc = sample(model, sampler_hmc, num_iterations; init_params=copy(init_params), progress=false)
samples_hmc = sample(
make_rng(seed), model, sampler_hmc, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
progress=false
)
chain_hmc = AbstractMCMC.bundle_samples(
samples_hmc, MCMCTempering.maybe_wrap_model(model), sampler_hmc, samples_hmc[1], MCMCChains.Chains;
param_names=param_names,
Expand All @@ -375,8 +407,9 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_hmc, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
init_params=copy(init_params),
make_rng(seed), model, sampler_tempered, num_iterations;
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
param_names=param_names,
progress=false,
Expand All @@ -398,9 +431,11 @@ end
num_iterations=num_iterations,
adapt=false,
mean_swap_rate_bound=0.1,
init_params=copy(init_params),
initial_params=copy(initial_params),
param_names=param_names,
progress=false
progress=false,
n_adapts=0, # FIXME(torfjelde): Remove once AHMC.jl has fixed.
rng=make_rng(seed),
)
map_parameters!(b, chain_tempered)
compare_chains(
Expand All @@ -421,8 +456,8 @@ end

# Sample using MALA.
chain_mh = AbstractMCMC.sample(
model, sampler_mh, num_iterations;
init_params=copy(init_params),
make_rng(), model, sampler_mh, num_iterations;
initial_params=copy(initial_params),
progress=false,
chain_type=MCMCChains.Chains,
param_names=param_names,
Expand All @@ -432,8 +467,8 @@ end
# Make sure that we get the "same" result when only using the inverse temperature 1.
sampler_tempered = MCMCTempering.TemperedSampler(sampler_mh, [1])
chain_tempered = sample(
model, sampler_tempered, num_iterations;
init_params=copy(init_params),
make_rng(), model, sampler_tempered, num_iterations;
initial_params=copy(initial_params),
chain_type=MCMCChains.Chains,
param_names=param_names,
progress=false,
Expand All @@ -455,8 +490,9 @@ end
num_iterations=num_iterations,
adapt=false,
mean_swap_rate_bound=0.1,
init_params=copy(init_params),
param_names=param_names
initial_params=copy(initial_params),
param_names=param_names,
rng=make_rng(),
)
map_parameters!(b, chain_tempered)

Expand Down
Loading
Loading