 1 │ m │ 0.824853 │
 ```
 """
-struct ESS <: InferenceAlgorithm end
+struct ESS <: AbstractSampler end
+
+DynamicPPL.initialsampler(::ESS) = DynamicPPL.SampleFromPrior()
+update_sample_kwargs(::ESS, ::Integer, kwargs) = kwargs
+get_adtype(::ESS) = nothing
+requires_unconstrained_space(::ESS) = false
 
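The four new hooks advertise ESS's requirements to the refactored sampling pipeline: initial values come from the prior, keyword arguments pass through untouched, no AD backend is needed, and no linking to unconstrained space is required. A rough sketch of how a driver might consult them; only the hook names come from this diff, the driver itself is hypothetical:

```julia
# Hypothetical driver logic, for illustration only -- not Turing's actual code.
function prepare_ldf(model, spl::ESS)
    @assert get_adtype(spl) === nothing           # ESS is gradient-free
    @assert !requires_unconstrained_space(spl)    # sample in the model's own space
    vi = DynamicPPL.VarInfo(model)                # initial state; a real driver would
    return DynamicPPL.LogDensityFunction(model, vi)  # consult DynamicPPL.initialsampler(spl)
end
```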
 # always accept in the first step
-function DynamicPPL.initialstep(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
-)
+function AbstractMCMC.step(rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS; kwargs...)
+    vi = ldf.varinfo
     for vn in keys(vi)
         dist = getdist(vi, vn)
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi), vi
+    return Transition(ldf.model, vi), vi
 end
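For orientation: the refactor changes the internal `step` signature (a `LogDensityFunction` instead of `Model` plus `Sampler`), but user-facing sampling should be unaffected. A minimal sketch, assuming the usual Turing exports:

```julia
using Turing

@model function gdemo(x)
    m ~ Normal(0, 1)       # Gaussian prior, as the check in `step` requires
    x ~ Normal(m, 0.5)
end

chain = sample(gdemo(1.0), ESS(), 1000)
```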
 
 function AbstractMCMC.step(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, ldf::LogDensityFunction, spl::ESS, vi::AbstractVarInfo; kwargs...
 )
     # obtain previous sample
     f = vi[:]
@@ -45,14 +49,13 @@ function AbstractMCMC.step(
     oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
 
     # compute next state
+    # Note: `f_loglikelihood` effectively calculates the log-likelihood (not
+    # log-joint, despite the use of `LDP.logdensity`) because `tilde_assume` is
+    # overloaded on `SamplingContext(rng, ESS(), ...)` below.
+    f_loglikelihood = Base.Fix1(LogDensityProblems.logdensity, ldf)
     sample, state = AbstractMCMC.step(
         rng,
-        EllipticalSliceSampling.ESSModel(
-            ESSPrior(model, spl, vi),
-            DynamicPPL.LogDensityFunction(
-                model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
-            ),
-        ),
+        EllipticalSliceSampling.ESSModel(ESSPrior(ldf.model, spl, vi), f_loglikelihood),
         EllipticalSliceSampling.ESS(),
         oldstate,
     )
@@ -61,67 +64,57 @@ function AbstractMCMC.step(
     vi = DynamicPPL.unflatten(vi, sample)
     vi = setlogp!!(vi, state.loglikelihood)
 
-    return Transition(model, vi), vi
+    return Transition(ldf.model, vi), vi
 end
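The actual proposal mechanics are delegated to `EllipticalSliceSampling.step`. As a reminder of what that does (and why a Gaussian prior is mandatory), here is a self-contained sketch of one elliptical slice update in the style of Murray, Adams & MacKay (2010); it illustrates the algorithm, not the package's implementation:

```julia
using Random

# One elliptical slice update: given the current point f, a log-likelihood
# function, the prior mean μ, and a sampler for the Gaussian prior, return
# the next point. Proposals lie on an ellipse through f and an auxiliary
# prior draw, so they preserve the prior by construction.
function ess_update(rng, f, loglik, μ, rand_prior)
    ν = rand_prior(rng) .- μ                    # centred auxiliary draw
    logy = loglik(f) + log(rand(rng))           # slice threshold
    θ = 2π * rand(rng)                          # initial angle
    θmin, θmax = θ - 2π, θ                      # full bracket
    while true
        f′ = @. (f - μ) * cos(θ) + ν * sin(θ) + μ
        loglik(f′) > logy && return f′          # accept: point is on the slice
        θ < 0 ? (θmin = θ) : (θmax = θ)         # otherwise shrink the bracket
        θ = θmin + (θmax - θmin) * rand(rng)
    end
end
```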
 
 # Prior distribution of considered random variable
-struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
+struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
     model::M
-    sampler::S
+    sampler::ESS
     varinfo::V
     μ::T
 
-    function ESSPrior{M,S,V}(
-        model::M, sampler::S, varinfo::V
-    ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
+    function ESSPrior{M,V}(
+        model::M, sampler::ESS, varinfo::V
+    ) where {M<:Model,V<:AbstractVarInfo}
         vns = keys(varinfo)
         μ = mapreduce(vcat, vns) do vn
             dist = getdist(varinfo, vn)
             EllipticalSliceSampling.isgaussian(typeof(dist)) ||
                 error("[ESS] only supports Gaussian prior distributions")
             DynamicPPL.tovec(mean(dist))
         end
-        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
+        return new{M,V,typeof(μ)}(model, sampler, varinfo, μ)
     end
 end
 
-function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
-    return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
+function ESSPrior(model::Model, sampler::ESS, varinfo::AbstractVarInfo)
+    return ESSPrior{typeof(model),typeof(varinfo)}(model, sampler, varinfo)
 end
 
 # Ensure that the prior is a Gaussian distribution (checked in the constructor)
 EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
 
 # Only define out-of-place sampling
 function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
-    sampler = p.sampler
-    varinfo = p.varinfo
-    # TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
-    vns = keys(varinfo)
-    for vn in vns
-        set_flag!(varinfo, vn, "del")
+    # TODO(penelopeysm): This is ugly -- need to set 'del' flag because
+    # otherwise DynamicPPL.SampleFromPrior will just use the existing
+    # parameters in the varinfo. In general SampleFromPrior etc. need to be
+    # reworked.
+    for vn in keys(p.varinfo)
+        set_flag!(p.varinfo, vn, "del")
     end
-    p.model(rng, varinfo, sampler)
-    return varinfo[:]
+    _, vi = DynamicPPL.evaluate!!(p.model, p.varinfo, SamplingContext(rng, p.sampler))
+    return vi[:]
 end
 
 # Mean of prior distribution
 Distributions.mean(p::ESSPrior) = p.μ
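A hedged usage sketch of `ESSPrior`'s interface (`rand` above and `mean` just defined); the surrounding `model` and `vi` are assumed to be in scope and are not part of the diff:

```julia
# Assuming `model::Model` and a matching `vi::AbstractVarInfo` exist:
prior = ESSPrior(model, ESS(), vi)
ν = rand(Random.default_rng(), prior)   # fresh flattened draw from the prior
μ = mean(prior)                         # prior mean, precomputed in the constructor
@assert length(ν) == length(μ)          # both live in the flattened parameter space
```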
 
-# Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
-    DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
-
-(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
-
 function DynamicPPL.tilde_assume(
-    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
+    rng::Random.AbstractRNG, ::DefaultContext, ::ESS, right, vn, vi
 )
     return DynamicPPL.tilde_assume(
         rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
     )
 end
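This overload is what makes `f_loglikelihood` in `step` a likelihood rather than a log-joint: every tilde-assume statement is re-dispatched to `LikelihoodContext`, which drops the prior term, so only observe statements accumulate log-probability. A hand-worked illustration; the model and numbers are invented for this note:

```julia
using Distributions

# Consider a model like:
#     m ~ Normal(0, 1)   # assume: prior term dropped under LikelihoodContext
#     x ~ Normal(m, 1)   # observe: contributes logpdf(Normal(m, 1), x)
# With the overload active, the log density at m = 0.5 given x = 2.0 is
logpdf(Normal(0.5, 1), 2.0)   # likelihood only -- no logpdf(Normal(0, 1), 0.5) term
```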
-
-function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
-    return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
-end