Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add acceptance rate info to Transition struct #88

Merged
merged 6 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ export
# Reexports
export sample, MCMCThreads, MCMCDistributed, MCMCSerial

# Abstract type for MH-style samplers. Needs better name?
# Abstract type for MH-style samplers. Needs better name?
abstract type MHSampler <: AbstractMCMC.AbstractSampler end

# Abstract type for MH-style transitions.
abstract type AbstractTransition end

# Define a model type. Stores the log density function and the data to
# Define a model type. Stores the log density function and the data to
# evaluate the log density on.
"""
DensityModel{F} <: AbstractModel
Expand All @@ -53,17 +53,19 @@ end

const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}

# Create a very basic Transition type, only stores the
# parameter draws and the log probability of the draw.
# Create a very basic Transition type, stores the
# parameter draws, the log probability of the draw,
# and the draw information until this point
struct Transition{T,L<:Real} <: AbstractTransition
params :: T
lp :: L
accepted :: Bool
end

# Store the new draw and its log density.
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
function Transition(model::AbstractMCMC.LogDensityModel, params)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
# Store the new draw, its log density, and draw information
Transition(model::DensityModelOrLogDensityModel, params, accepted) = Transition(params, logdensity(model, params), accepted)
function Transition(model::AbstractMCMC.LogDensityModel, params, accepted)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params), accepted)
end

# Calculate the density of the model given some parameterization.
Expand Down Expand Up @@ -128,7 +130,7 @@ function __init__()
if exc.f === logdensity_and_gradient && length(arg_types) == 2 && first(arg_types) <: DensityModel && isempty(kwargs)
print(io, "\\nDid you forget to load ForwardDiff?")
end
end
end
@static if !isdefined(Base, :get_extension)
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("../ext/AdvancedMHMCMCChainsExt.jl")
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("../ext/AdvancedMHStructArraysExt.jl")
Expand Down
25 changes: 17 additions & 8 deletions src/MALA.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
struct MALA{D} <: MHSampler
proposal::D

MALA{D}(proposal::D) where {D} = new{D}(proposal)
accepted_draws :: Int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't remember the previous discussion but this design seems a bit surprising to me. I think the number of accepted and total draws is not a property of the sampler but rather of the resulting chain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the thought here was that samplers have a record of success or failure. This would be useful for implementing delayed rejection or other sample massaging. But that information could probably be retrieved from the vector of transitions along the way.

With that in mind, maybe we can stick with Transition objects just storing whether the sample was accepted or rejected and compute final stats in chain construction or in sampling if needed. I know this PR is already a mess, but I'll revert and make the adjustments.

total_draws :: Int
end

MALA(p,ad,td) = MALA(p,ad,td)

# If we were given a RandomWalkProposal, just use that instead.
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d)
MALA(d::RandomWalkProposal) = MALA{typeof(d)}(d,0,0)

# Create a RandomWalkProposal if we weren't given one already.
MALA(d) = MALA(RandomWalkProposal(d))
Expand All @@ -15,13 +17,14 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
params::T
lp::L
gradient::G
accepted :: Bool
end

logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
return GradientTransition(params, logdensity_and_gradient(model, params)...)
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted)
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted)
end

check_capabilities(model::DensityModelOrLogDensityModel) = nothing
Expand All @@ -44,7 +47,7 @@ function AbstractMCMC.step(
kwargs...
)
check_capabilities(model)

# Extract value and gradient of the log density of the current state.
state = transition_prev.params
logdensity_state = transition_prev.lp
Expand All @@ -69,10 +72,16 @@ function AbstractMCMC.step(

# Decide whether to return the previous params or the new one.
transition = if -Random.randexp(rng) < logα
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate)
accepted_draws = sampler.accepted_draws + 1
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
else
transition_prev
accepted_draws = sampler.accepted_draws
candidate = transition_prev.params
lp = transition_prev.lp
gradient = transition_prev.gradient
GradientTransition(candidate, lp, gradient, false)
end
sampler = MALA(sampler.proposal, accepted_draws, sampler.total_draws+1)
cnrrobertson marked this conversation as resolved.
Show resolved Hide resolved

return transition, transition
end
Expand Down
42 changes: 27 additions & 15 deletions src/emcee.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
struct Ensemble{D} <: MHSampler
n_walkers::Int
proposal::D
accepted_draws :: Int
total_draws :: Int
end

function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
return [Transition(model, x) for x in params]
Ensemble(n,p) = Ensemble{typeof(p)}(n,p,0,0)

function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params, accepted)
return [Transition(model, x, accepted) for x in params]
end

# Define the other sampling steps.
Expand Down Expand Up @@ -68,7 +72,7 @@ end
StretchProposal(p) = StretchProposal(p, 2.0)

function move(
rng::Random.AbstractRNG,
rng::Random.AbstractRNG,
spl::Ensemble{<:StretchProposal},
model::DensityModelOrLogDensityModel,
walker::Transition,
Expand All @@ -84,17 +88,25 @@ function move(
# Make new parameters
y = @. other_walker.params + z * (walker.params - other_walker.params)

# Construct a new walker
new_walker = Transition(model, y)
# Construct a temporary new walker
temp_walker = Transition(model, y, true)

# Calculate accept/reject value.
alpha = alphamult + new_walker.lp - walker.lp
alpha = alphamult + temp_walker.lp - walker.lp

if -Random.randexp(rng) <= alpha
accepted_draws = spl.accepted_draws + 1
new_walker = Transition(model, y, true)
return new_walker
else
return walker
accepted_draws = spl.accepted_draws
params = walker.params
lp = walker.lp
old_walker = Transition(params, lp, false)
return old_walker
end

spl = Ensemble(spl.n_walkers, spl.proposal, accepted_draws, spl.total_draws+1)
end

#########################
Expand Down Expand Up @@ -124,7 +136,7 @@ end

# theta_min = theta - 2.0*π
# theta_max = theta

# f = walker.params
# while true
# stheta, ctheta = sincos(theta)
Expand All @@ -136,15 +148,15 @@ end
# if new_walker.lp > y
# return new_walker
# else
# if theta < 0
# if theta < 0
# theta_min = theta
# else
# theta_max = theta
# end

# theta = theta_min + (theta_max - theta_min) * rand()
# end
# end
# end
# end

#####################
Expand Down Expand Up @@ -180,15 +192,15 @@ end

# theta_min = theta - 2.0*π
# theta_max = theta

# f = walker.params

# i = 0
# while true
# i += 1

# stheta, ctheta = sincos(theta)

# f_prime = f .* ctheta + nu .* stheta

# new_walker = Transition(model, f_prime)
Expand All @@ -198,13 +210,13 @@ end
# if new_walker.lp > y
# return new_walker
# else
# if theta < 0
# if theta < 0
# theta_min = theta
# else
# theta_max = theta
# end

# theta = theta_min + (theta_max - theta_min) * rand()
# end
# end
# end
# end
28 changes: 19 additions & 9 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
MetropolisHastings{D}

`MetropolisHastings` has one field, `proposal`.
`MetropolisHastings` has one field, `proposal`.
`proposal` is a `Proposal`, `NamedTuple` of `Proposal`, or `Array{Proposal}` in the shape of your data.
For example, if you wanted the sampler to return a `NamedTuple` with shape

Expand Down Expand Up @@ -38,13 +38,17 @@ none is given, the initial parameters will be drawn from the sampler's proposals
- `param_names` is a vector of strings to be assigned to parameters. This is only
used if `chain_type=Chains`.
- `chain_type` is the type of chain you would like returned to you. Supported
types are `chain_type=Chains` if `MCMCChains` is imported, or
types are `chain_type=Chains` if `MCMCChains` is imported, or
`chain_type=StructArray` if `StructArrays` is imported.
"""
struct MetropolisHastings{D} <: MHSampler
proposal::D
accepted_draws :: Int
total_draws :: Int
end

MetropolisHastings(p) = MetropolisHastings{typeof(p)}(p,0,0)

StaticMH(d) = MetropolisHastings(StaticProposal(d))
StaticMH(d::Int) = MetropolisHastings(StaticProposal(MvNormal(Zeros(d), I)))
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
Expand All @@ -62,12 +66,12 @@ function propose(
return propose(rng, sampler.proposal, model, transition_prev.params)
end

function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, accepted)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity)
return transition(sampler, model, params, logdensity, accepted)
end
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
return Transition(params, logdensity)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real, accepted)
return Transition(params, logdensity, accepted)
end

# Define the first sampling step.
Expand All @@ -81,7 +85,7 @@ function AbstractMCMC.step(
kwargs...
)
params = init_params === nothing ? propose(rng, sampler, model) : init_params
cpfiffer marked this conversation as resolved.
Show resolved Hide resolved
transition = AdvancedMH.transition(sampler, model, params)
transition = AdvancedMH.transition(sampler, model, params, false)
return transition, transition
end

Expand All @@ -106,11 +110,17 @@ function AbstractMCMC.step(

# Decide whether to return the previous params or the new one.
transition = if -Random.randexp(rng) < logα
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate)
accepted_draws = sampler.accepted_draws + 1
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate, true)
else
transition_prev
accepted_draws = sampler.accepted_draws
params = transition_prev.params
lp = transition_prev.lp
Transition(params, lp, false)
end

sampler = typeof(sampler)(sampler.proposal, accepted_draws, sampler.total_draws+1)

return transition, transition
end

Expand Down
Loading