Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add acceptance rate info to Transition struct #88

Merged
merged 6 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ export
# Reexports
export sample, MCMCThreads, MCMCDistributed, MCMCSerial

# Abstract type for MH-style samplers. Needs better name?
# Abstract type for MH-style samplers. Needs better name?
abstract type MHSampler <: AbstractMCMC.AbstractSampler end

# Abstract type for MH-style transitions.
abstract type AbstractTransition end

# Define a model type. Stores the log density function and the data to
# Define a model type. Stores the log density function and the data to
# evaluate the log density on.
"""
DensityModel{F} <: AbstractModel
Expand All @@ -53,17 +53,19 @@ end

const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}

# Create a very basic Transition type, only stores the
# parameter draws and the log probability of the draw.
# Create a very basic Transition type, stores the
# parameter draws, the log probability of the draw,
# and the draw information until this point
struct Transition{T,L<:Real} <: AbstractTransition
params :: T
lp :: L
accepted :: Bool
end

# Store the new draw and its log density.
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
function Transition(model::AbstractMCMC.LogDensityModel, params)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
# Store the new draw, its log density, and draw information
Transition(model::DensityModelOrLogDensityModel, params, accepted) = Transition(params, logdensity(model, params), accepted)
function Transition(model::AbstractMCMC.LogDensityModel, params, accepted)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params), accepted)
end

# Calculate the density of the model given some parameterization.
Expand Down Expand Up @@ -128,7 +130,7 @@ function __init__()
if exc.f === logdensity_and_gradient && length(arg_types) == 2 && first(arg_types) <: DensityModel && isempty(kwargs)
print(io, "\\nDid you forget to load ForwardDiff?")
end
end
end
@static if !isdefined(Base, :get_extension)
@require MCMCChains="c7f686f2-ff18-58e9-bc7b-31028e88f75d" include("../ext/AdvancedMHMCMCChainsExt.jl")
@require StructArrays="09ab397b-f2b6-538f-b94a-2f83cf4a842a" include("../ext/AdvancedMHStructArraysExt.jl")
Expand Down
14 changes: 9 additions & 5 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
params::T
lp::L
gradient::G
accepted::Bool
end

logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
return GradientTransition(params, logdensity_and_gradient(model, params)...)
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params, accepted)
return GradientTransition(params, logdensity_and_gradient(model, params)..., accepted)
end

check_capabilities(model::DensityModelOrLogDensityModel) = nothing
Expand All @@ -44,7 +45,7 @@ function AbstractMCMC.step(
kwargs...
)
check_capabilities(model)

# Extract value and gradient of the log density of the current state.
state = transition_prev.params
logdensity_state = transition_prev.lp
Expand All @@ -69,9 +70,12 @@ function AbstractMCMC.step(

# Decide whether to return the previous params or the new one.
transition = if -Random.randexp(rng) < logα
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate)
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
else
transition_prev
candidate = transition_prev.params
lp = transition_prev.lp
gradient = transition_prev.gradient
GradientTransition(candidate, lp, gradient, false)
end

return transition, transition
Expand Down
34 changes: 19 additions & 15 deletions src/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ struct Ensemble{D} <: MHSampler
proposal::D
end

function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
return [Transition(model, x) for x in params]
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params, accepted)
return [Transition(model, x, accepted) for x in params]
end

# Define the other sampling steps.
Expand Down Expand Up @@ -68,7 +68,7 @@ end
StretchProposal(p) = StretchProposal(p, 2.0)

function move(
rng::Random.AbstractRNG,
rng::Random.AbstractRNG,
spl::Ensemble{<:StretchProposal},
model::DensityModelOrLogDensityModel,
walker::Transition,
Expand All @@ -84,16 +84,20 @@ function move(
# Make new parameters
y = @. other_walker.params + z * (walker.params - other_walker.params)

# Construct a new walker
new_walker = Transition(model, y)
# Construct a temporary new walker
temp_walker = Transition(model, y, true)

# Calculate accept/reject value.
alpha = alphamult + new_walker.lp - walker.lp
alpha = alphamult + temp_walker.lp - walker.lp

if -Random.randexp(rng) <= alpha
new_walker = Transition(model, y, true)
return new_walker
else
return walker
params = walker.params
lp = walker.lp
old_walker = Transition(params, lp, false)
return old_walker
end
end

Expand Down Expand Up @@ -124,7 +128,7 @@ end

# theta_min = theta - 2.0*π
# theta_max = theta

# f = walker.params
# while true
# stheta, ctheta = sincos(theta)
Expand All @@ -136,15 +140,15 @@ end
# if new_walker.lp > y
# return new_walker
# else
# if theta < 0
# if theta < 0
# theta_min = theta
# else
# theta_max = theta
# end

# theta = theta_min + (theta_max - theta_min) * rand()
# end
# end
# end
# end

#####################
Expand Down Expand Up @@ -180,15 +184,15 @@ end

# theta_min = theta - 2.0*π
# theta_max = theta

# f = walker.params

# i = 0
# while true
# i += 1

# stheta, ctheta = sincos(theta)

# f_prime = f .* ctheta + nu .* stheta

# new_walker = Transition(model, f_prime)
Expand All @@ -198,13 +202,13 @@ end
# if new_walker.lp > y
# return new_walker
# else
# if theta < 0
# if theta < 0
# theta_min = theta
# else
# theta_max = theta
# end

# theta = theta_min + (theta_max - theta_min) * rand()
# end
# end
# end
# end
20 changes: 11 additions & 9 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
MetropolisHastings{D}

`MetropolisHastings` has one field, `proposal`.
`MetropolisHastings` has one field, `proposal`.
`proposal` is a `Proposal`, `NamedTuple` of `Proposal`, or `Array{Proposal}` in the shape of your data.
For example, if you wanted the sampler to return a `NamedTuple` with shape

Expand Down Expand Up @@ -38,7 +38,7 @@ none is given, the initial parameters will be drawn from the sampler's proposals
- `param_names` is a vector of strings to be assigned to parameters. This is only
used if `chain_type=Chains`.
- `chain_type` is the type of chain you would like returned to you. Supported
types are `chain_type=Chains` if `MCMCChains` is imported, or
types are `chain_type=Chains` if `MCMCChains` is imported, or
`chain_type=StructArray` if `StructArrays` is imported.
"""
struct MetropolisHastings{D} <: MHSampler
Expand All @@ -62,12 +62,12 @@ function propose(
return propose(rng, sampler.proposal, model, transition_prev.params)
end

function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, accepted)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity)
return transition(sampler, model, params, logdensity, accepted)
end
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
return Transition(params, logdensity)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real, accepted)
return Transition(params, logdensity, accepted)
end

# Define the first sampling step.
Expand All @@ -81,7 +81,7 @@ function AbstractMCMC.step(
kwargs...
)
params = initial_params === nothing ? propose(rng, sampler, model) : initial_params
transition = AdvancedMH.transition(sampler, model, params)
transition = AdvancedMH.transition(sampler, model, params, false)
return transition, transition
end

Expand All @@ -106,9 +106,11 @@ function AbstractMCMC.step(

# Decide whether to return the previous params or the new one.
transition = if -Random.randexp(rng) < logα
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate)
AdvancedMH.transition(sampler, model, candidate, logdensity_candidate, true)
else
transition_prev
params = transition_prev.params
lp = transition_prev.lp
Transition(params, lp, false)
end

return transition, transition
Expand Down
Loading