diff --git a/Project.toml b/Project.toml index d3056ac..c932d98 100644 --- a/Project.toml +++ b/Project.toml @@ -16,8 +16,18 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +[weakdeps] +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" + +[extensions] +MCMCTemperingAdvancedHMCExt = ["AdvancedHMC"] +MCMCTemperingAdvancedMHExt = ["AdvancedMH"] + [compat] AbstractMCMC = "5" +AdvancedHMC = "0.6" +AdvancedMH = "0.8" ConcreteStructs = "0.2" Distributions = "0.24, 0.25" DocStringExtensions = "0.8, 0.9" @@ -27,3 +37,7 @@ MCMCChains = "5, 6" ProgressLogging = "0.1" Setfield = "0.7, 0.8, 1" julia = "1" + +[extras] +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" \ No newline at end of file diff --git a/ext/MCMCTemperingAdvancedHMCExt.jl b/ext/MCMCTemperingAdvancedHMCExt.jl new file mode 100644 index 0000000..86d6a85 --- /dev/null +++ b/ext/MCMCTemperingAdvancedHMCExt.jl @@ -0,0 +1,19 @@ +module MCMCTemperingAdvancedHMCExt + +using MCMCTempering: MCMCTempering, Setfield +using AdvancedHMC: AdvancedHMC + +MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value +MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition) + +# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. +function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp) + # NOTE: Need to recompute the gradient because it might be used in the next integration step. + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, params, state.transition.z.r; + ℓκ=state.transition.z.ℓκ + ) +end + +end diff --git a/test/compat.jl b/ext/MCMCTemperingAdvancedMHExt.jl similarity index 54% rename from test/compat.jl rename to ext/MCMCTemperingAdvancedMHExt.jl index 0f670d7..633b259 100644 --- a/test/compat.jl +++ b/ext/MCMCTemperingAdvancedMHExt.jl @@ -1,4 +1,8 @@ -# AdvancedMH.jl +module MCMCTemperingAdvancedMHExt + +using MCMCTempering: MCMCTempering, Setfield +using AdvancedMH: AdvancedMH + MCMCTempering.getparams_and_logprob(transition::AdvancedMH.Transition) = transition.params, transition.lp function MCMCTempering.setparams_and_logprob!!(transition::AdvancedMH.Transition, params, lp) Setfield.@set! transition.params = params @@ -17,16 +21,4 @@ function MCMCTempering.setparams_and_logprob!!(model, transition::AdvancedMH.Gra ) end -# AdvancedHMC.jl -MCMCTempering.getparams_and_logprob(t::AdvancedHMC.Transition) = t.z.θ, t.z.ℓπ.value -MCMCTempering.getparams_and_logprob(state::AdvancedHMC.HMCState) = MCMCTempering.getparams_and_logprob(state.transition) - -# TODO: Implement `state_from` instead, to avoid re-computation of gradients if possible. -function MCMCTempering.setparams_and_logprob!!(model, state::AdvancedHMC.HMCState, params, lp) - # NOTE: Need to recompute the gradient because it might be used in the next integration step. - hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) - return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( - hamiltonian, params, state.transition.z.r; - ℓκ=state.transition.z.ℓκ - ) end diff --git a/test/setup.jl b/test/setup.jl index 90195ca..1e8c531 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -19,4 +19,3 @@ using Turing: Turing, DynamicPPL include("utils.jl") include("test_utils.jl") -include("compat.jl")