Skip to content

Commit

Permalink
move acceptance statistics to MHSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
cnrrobertson committed Sep 19, 2023
1 parent 21e940e commit fcfafaf
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 40 deletions.
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Documentation for AdvancedMH.jl
## Structs
```@docs
MetropolisHastings
Transition
```

## Functions
Expand Down
12 changes: 5 additions & 7 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,16 @@ const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDen
# Create a very basic Transition type, stores the
# parameter draws, the log probability of the draw,
# and the draw information until this point
struct Transition{T,L<:Real,M<:Bool,D<:Int} <: AbstractTransition
struct Transition{T,L<:Real} <: AbstractTransition
params :: T
lp :: L
accepted :: M
accepted_draws :: D
total_draws :: D
accepted :: Bool
end

# Store the new draw, its log density, and draw information
Transition(model::DensityModelOrLogDensityModel, params, accepted, accepted_draws, total_draws) = Transition(params, logdensity(model, params), accepted, accepted_draws, total_draws)
function Transition(model::AbstractMCMC.LogDensityModel, params, accepted, accepted_draws, total_draws)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params), accepted, accepted_draws, total_draws)
Transition(model::DensityModelOrLogDensityModel, params, accepted) = Transition(params, logdensity(model, params), accepted)
function Transition(model::AbstractMCMC.LogDensityModel, params, accepted)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params), accepted)
end

# Calculate the density of the model given some parameterization.
Expand Down
28 changes: 14 additions & 14 deletions src/MALA.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
struct MALA{D} <: MHSampler
proposal::D

MALA{D}(proposal::D) where {D} = new{D}(proposal)
accepted_draws :: Int
total_draws :: Int
end

MALA(p,ad,td) = MALA(p,ad,td)

# If we were given a RandomWalkProposal, just use that instead.
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d,0,0)

# Create a RandomWalkProposal if we weren't given one already.
MALA(d) = MALA(RandomWalkProposal(d))


struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple},M<:Bool,D<:Int} <: AbstractTransition
struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{Vector, Real, NamedTuple}} <: AbstractTransition
params::T
lp::L
gradient::G
accepted :: M
accepted_draws :: D
total_draws :: D
accepted :: Bool
end

logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted, accepted_draws, total_draws)
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted, accepted_draws, total_draws)
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted)
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted)
end

check_capabilities(model::DensityModelOrLogDensityModel) = nothing
Expand Down Expand Up @@ -71,17 +71,17 @@ function AbstractMCMC.step(
logα = logdensity_candidate - logdensity_state + logratio_proposal_density

# Decide whether to return the previous params or the new one.
total_draws = transition_prev.total_draws + 1
transition = if -Random.randexp(rng) < logα
accepted_draws = transition_prev.accepted_draws + 1
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true, accepted_draws, total_draws)
accepted_draws = sampler.accepted_draws + 1
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
else
accepted_draws = transition_prev.accepted_draws
accepted_draws = sampler.accepted_draws
candidate = transition_prev.params
lp = transition_prev.lp
gradient = transition_prev.gradient
GradientTransition(candidate, lp, gradient, false, accepted_draws, total_draws)
GradientTransition(candidate, lp, gradient, false)
end
sampler = MALA(sampler.proposal, accepted_draws, sampler.total_draws+1)

return transition, transition
end
Expand Down
21 changes: 13 additions & 8 deletions src/emcee.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
struct Ensemble{D} <: MHSampler
n_walkers::Int
proposal::D
accepted_draws :: Int
total_draws :: Int
end

function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params, accepted, accepted_draws, total_draws)
return [Transition(model, x, accepted, accepted_draws, total_draws) for x in params]
Ensemble(n,p) = Ensemble{typeof(p)}(n,p,0,0)

function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params, accepted)
return [Transition(model, x, accepted) for x in params]
end

# Define the other sampling steps.
Expand Down Expand Up @@ -85,23 +89,24 @@ function move(
y = @. other_walker.params + z * (walker.params - other_walker.params)

# Construct a temporary new walker
temp_walker = Transition(model, y, true, walker.accepted_draws, walker.total_draws)
temp_walker = Transition(model, y, true)

# Calculate accept/reject value.
alpha = alphamult + temp_walker.lp - walker.lp

total_draws = walker.total_draws + 1
if -Random.randexp(rng) <= alpha
accepted_draws = walker.accepted_draws + 1
new_walker = Transition(model, y, true, accepted_draws, total_draws)
accepted_draws = spl.accepted_draws + 1
new_walker = Transition(model, y, true)
return new_walker
else
accepted_draws = walker.accepted_draws
accepted_draws = spl.accepted_draws
params = walker.params
lp = walker.lp
old_walker = Transition(params, lp, false, accepted_draws, total_draws)
old_walker = Transition(params, lp, false)
return old_walker
end

spl = Ensemble(spl.n_walkers, spl.proposal, accepted_draws, spl.total_draws+1)
end

#########################
Expand Down
25 changes: 15 additions & 10 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ types are `chain_type=Chains` if `MCMCChains` is imported, or
"""
struct MetropolisHastings{D} <: MHSampler
proposal::D
accepted_draws :: Int
total_draws :: Int
end

MetropolisHastings(p) = MetropolisHastings{typeof(p)}(p,0,0)

StaticMH(d) = MetropolisHastings(StaticProposal(d))
StaticMH(d::Int) = MetropolisHastings(StaticProposal(MvNormal(Zeros(d), I)))
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
Expand All @@ -62,12 +66,12 @@ function propose(
return propose(rng, sampler.proposal, model, transition_prev.params)
end

function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, accepted, accepted_draws, total_draws)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, accepted)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity, accepted, accepted_draws, total_draws)
return transition(sampler, model, params, logdensity, accepted)
end
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real, accepted, accepted_draws, total_draws)
return Transition(params, logdensity, accepted, accepted_draws, total_draws)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real, accepted)
return Transition(params, logdensity, accepted)
end

# Define the first sampling step.
Expand All @@ -81,7 +85,7 @@ function AbstractMCMC.step(
kwargs...
)
params = init_params === nothing ? propose(rng, sampler, model) : init_params
transition = AdvancedMH.transition(sampler, model, params, false, 0, 0)
transition = AdvancedMH.transition(sampler, model, params, false)
return transition, transition
end

Expand All @@ -105,17 +109,18 @@ function AbstractMCMC.step(
logratio_proposal_density(sampler, transition_prev, candidate)

# Decide whether to return the previous params or the new one.
total_draws = transition_prev.total_draws + 1
transition = if -Random.randexp(rng) < logα
accepted_draws = transition_prev.accepted_draws + 1
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate, true, accepted_draws, total_draws)
accepted_draws = sampler.accepted_draws + 1
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate, true)
else
accepted_draws = sampler.accepted_draws
params = transition_prev.params
lp = transition_prev.lp
accepted_draws = transition_prev.accepted_draws
Transition(params, lp, false, accepted_draws, total_draws)
Transition(params, lp, false)
end

sampler = typeof(sampler)(sampler.proposal, accepted_draws, sampler.total_draws+1)

return transition, transition
end

Expand Down

0 comments on commit fcfafaf

Please sign in to comment.