1- # InferenceAlgorithm interface
1+ # AbstractSampler interface for Turing
22
3- abstract type Hamiltonian <: InferenceAlgorithm end
3+ abstract type Hamiltonian <: AbstractMCMC.AbstractSampler end
44
5- DynamicPPL. initialsampler (:: Sampler{<: Hamiltonian} ) = SampleFromUniform ()
5+ DynamicPPL. initialsampler (:: Hamiltonian ) = DynamicPPL . SampleFromUniform ()
66requires_unconstrained_space (:: Hamiltonian ) = true
77# TODO (penelopeysm): This is really quite dangerous code because it implicitly
88# assumes that any concrete type that subtypes `Hamiltonian` has an adtype
152152function AbstractMCMC. step (
153153 rng:: AbstractRNG ,
154154 ldf:: LogDensityFunction ,
155- spl:: Sampler{<: Hamiltonian} ;
155+ spl:: Hamiltonian ;
156156 initial_params= nothing ,
157157 nadapts= 0 ,
158158 kwargs... ,
@@ -165,7 +165,7 @@ function AbstractMCMC.step(
165165 has_initial_params = initial_params != = nothing
166166
167167 # Create a Hamiltonian.
168- metricT = getmetricT (spl. alg )
168+ metricT = getmetricT (spl)
169169 metric = metricT (length (theta))
170170 lp_func = Base. Fix1 (LogDensityProblems. logdensity, ldf)
171171 lp_grad_func = Base. Fix1 (LogDensityProblems. logdensity_and_gradient, ldf)
@@ -184,23 +184,23 @@ function AbstractMCMC.step(
184184 log_density_old = getlogp (vi)
185185
186186 # Find good eps if not provided one
187- if iszero (spl. alg . ϵ)
187+ if iszero (spl. ϵ)
188188 ϵ = AHMC. find_good_stepsize (rng, hamiltonian, theta)
189189 @info " Found initial step size" ϵ
190190 else
191- ϵ = spl. alg . ϵ
191+ ϵ = spl. ϵ
192192 end
193193
194194 # Generate a kernel.
195- kernel = make_ahmc_kernel (spl. alg , ϵ)
195+ kernel = make_ahmc_kernel (spl, ϵ)
196196
197197 # Create initial transition and state.
198198 # Already perform one step since otherwise we don't get any statistics.
199199 t = AHMC. transition (rng, hamiltonian, kernel, z)
200200
201201 # Adaptation
202- adaptor = AHMCAdaptor (spl. alg , hamiltonian. metric; ϵ= ϵ)
203- if spl. alg isa AdaptiveHamiltonian
202+ adaptor = AHMCAdaptor (spl, hamiltonian. metric; ϵ= ϵ)
203+ if spl isa AdaptiveHamiltonian
204204 hamiltonian, kernel, _ = AHMC. adapt! (
205205 hamiltonian, kernel, adaptor, 1 , nadapts, t. z. θ, t. stat. acceptance_rate
206206 )
224224function AbstractMCMC. step (
225225 rng:: Random.AbstractRNG ,
226226 ldf:: LogDensityFunction ,
227- spl:: Sampler{<: Hamiltonian} ,
227+ spl:: Hamiltonian ,
228228 state:: HMCState ;
229229 nadapts= 0 ,
230230 kwargs... ,
@@ -236,7 +236,7 @@ function AbstractMCMC.step(
236236
237237 # Adaptation
238238 i = state. i + 1
239- if spl. alg isa AdaptiveHamiltonian
239+ if spl isa AdaptiveHamiltonian
240240 hamiltonian, kernel, _ = AHMC. adapt! (
241241 hamiltonian,
242242 state. kernel,
@@ -276,7 +276,7 @@ function get_hamiltonian(model, spl, vi, state, n)
276276 # using leafcontext(model.context) so could we just remove the argument
277277 # entirely?)
278278 DynamicPPL. SamplingContext (spl, DynamicPPL. leafcontext (model. context));
279- adtype= spl. alg . adtype,
279+ adtype= spl. adtype,
280280 )
281281 lp_func = Base. Fix1 (LogDensityProblems. logdensity, ldf)
282282 lp_grad_func = Base. Fix1 (LogDensityProblems. logdensity_and_gradient, ldf)
@@ -441,17 +441,17 @@ getmetricT(::NUTS{<:Any,metricT}) where {metricT} = metricT
441441# #### HMC core functions
442442# ####
443443
444- getstepsize (sampler:: Sampler{<: Hamiltonian} , state) = sampler. alg . ϵ
445- getstepsize (sampler:: Sampler{<: AdaptiveHamiltonian} , state) = AHMC. getϵ (state. adaptor)
444+ getstepsize (sampler:: Hamiltonian , state) = sampler. ϵ
445+ getstepsize (sampler:: AdaptiveHamiltonian , state) = AHMC. getϵ (state. adaptor)
446446function getstepsize (
447- sampler:: Sampler{<: AdaptiveHamiltonian} ,
447+ sampler:: AdaptiveHamiltonian ,
448448 state:: HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation} ,
449449) where {TV,TKernel,THam,PhType}
450450 return state. kernel. τ. integrator. ϵ
451451end
452452
453- gen_metric (dim:: Int , spl:: Sampler{<: Hamiltonian} , state) = AHMC. UnitEuclideanMetric (dim)
454- function gen_metric (dim:: Int , spl:: Sampler{<: AdaptiveHamiltonian} , state)
453+ gen_metric (dim:: Int , spl:: Hamiltonian , state) = AHMC. UnitEuclideanMetric (dim)
454+ function gen_metric (dim:: Int , spl:: AdaptiveHamiltonian , state)
455455 return AHMC. renew (state. hamiltonian. metric, AHMC. getM⁻¹ (state. adaptor. pc))
456456end
457457
@@ -476,13 +476,11 @@ end
476476# ###
477477# ### Compiler interface, i.e. tilde operators.
478478# ###
479- function DynamicPPL. assume (
480- rng, :: Sampler{<:Hamiltonian} , dist:: Distribution , vn:: VarName , vi
481- )
479+ function DynamicPPL. assume (rng, :: Hamiltonian , dist:: Distribution , vn:: VarName , vi)
482480 return DynamicPPL. assume (dist, vn, vi)
483481end
484482
485- function DynamicPPL. observe (:: Sampler{<: Hamiltonian} , d:: Distribution , value, vi)
483+ function DynamicPPL. observe (:: Hamiltonian , d:: Distribution , value, vi)
486484 return DynamicPPL. observe (d, value, vi)
487485end
488486
0 commit comments