Skip to content

Commit

Permalink
Added extensions for AdvancedHMC.jl and AdvancedMH.jl (#161)
Browse files Browse the repository at this point in the history
* Added extensions for AdvancedHMC.jl and AdvancedMH.jl

* Bump Julia compat entry to 1.9 + CI to 1.10 (which will be LTS soon)

* Relaxed test a bit
  • Loading branch information
torfjelde authored Oct 8, 2024
1 parent 600b82c commit deb9668
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.7'
- '1.10'
- '1'
os:
- ubuntu-latest
Expand Down
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[weakdeps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"

[extensions]
MCMCTemperingAdvancedHMCExt = ["AdvancedHMC"]
MCMCTemperingAdvancedMHExt = ["AdvancedMH"]

[compat]
AbstractMCMC = "5"
AdvancedHMC = "0.6"
AdvancedMH = "0.8"
ConcreteStructs = "0.2"
Distributions = "0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -26,4 +36,8 @@ LogDensityProblems = "2"
MCMCChains = "5, 6"
ProgressLogging = "0.1"
Setfield = "0.7, 0.8, 1"
julia = "1"
julia = "1.9"

[extras]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
19 changes: 19 additions & 0 deletions ext/MCMCTemperingAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MCMCTemperingAdvancedHMCExt

using MCMCTempering: MCMCTempering, Setfield
using AdvancedHMC: AdvancedHMC

MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value
MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition)

# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end

end
18 changes: 5 additions & 13 deletions test/compat.jl → ext/MCMCTemperingAdvancedMHExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# AdvancedMH.jl
module MCMCTemperingAdvancedMHExt

using MCMCTempering: MCMCTempering, Setfield
using AdvancedMH: AdvancedMH

MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp
function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp)
Setfield.@set! transition.params = params
Expand All @@ -17,16 +21,4 @@ function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.Gra
)
end

# AdvancedHMC.jl
MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value
MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition)

# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible.
function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp)
# NOTE: Need to recompute the gradient because it might be used in the next integration step.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, params, state.transition.z.r;
ℓκ=state.transition.z.ℓκ
)
end
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@
params_resolved = map(first MCMCTempering.getparams, transitions_resolved_inner)
mean_tmp = vec(median(params_resolved; dims=2))
for i in 1:2
@test mean_tmp[i] 5.0 atol = 0.5
@test mean_tmp[i] 5.0 atol = 0.6
end

# A composition of `SwapSampler` and `MultiSampler` has special `AbstractMCMC.bundle_samples`.
Expand Down
1 change: 0 additions & 1 deletion test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ using Turing: Turing, DynamicPPL

include("utils.jl")
include("test_utils.jl")
include("compat.jl")

0 comments on commit deb9668

Please sign in to comment.