@@ -124,85 +124,124 @@ end
124124# #####################
125125# Default Transition #
126126# #####################
127- # Default
128- getstats (t) = nothing
127+ getstats (:: Any ) = NamedTuple ()
129128
129+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130131abstract type AbstractTransition end
131132
132- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133134 θ:: T
134- lp:: F # TODO : merge `lp` with `stat`
135- stat:: S
136- end
137-
138- Transition (θ, lp) = Transition (θ, lp, nothing )
139- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140- # TODO (DPPL0.37/penelopeysm): Fix this
141- θ = getparams (model, vi)
142- lp = getlogjoint_internal (vi)
143- return Transition (θ, lp, getstats (t))
135+ logprior:: F
136+ loglikelihood:: F
137+ stat:: N
138+
139+ """
140+ Transition(model::Model, vi::AbstractVarInfo, params::AbstractVector, sampler_transition)
141+
142+ Construct a new `Turing.Inference.Transition` object using the outputs of a sampler step.
143+
144+ Here, `vi` represents a VarInfo which in general may have junk contents (both
145+ parameters and accumulators). The role of this method is to re-evaluate `model` by inserting
146+ the new `params` (provided by the sampler) into the VarInfo `vi`.
147+
148+ `sampler_transition` is the transition object returned by the sampler itself and is only used
149+ to extract statistics of interest.
150+
151+ !!! warning "Parameters must match varinfo linking status"
152+ It is mandatory that the vector of parameters provided line up exactly with how the
153+ VarInfo `vi` is linked. Otherwise, this can silently produce incorrect results.
154+ """
155+ function Transition (
156+ model:: DynamicPPL.Model ,
157+ vi:: AbstractVarInfo ,
158+ parameters:: AbstractVector ,
159+ sampler_transition,
160+ )
161+ # To be safe...
162+ vi = deepcopy (vi)
163+ # Set the parameters and re-evaluate with the appropriate accumulators
164+ vi = DynamicPPL. unflatten (vi, parameters)
165+ vi = DynamicPPL. setaccs!! (
166+ vi,
167+ (
168+ DynamicPPL. ValuesAsInModelAccumulator (true ),
169+ DynamicPPL. LogPriorAccumulator (),
170+ DynamicPPL. LogLikelihoodAccumulator (),
171+ ),
172+ )
173+ _, vi = DynamicPPL. evaluate!! (model, vi)
174+
175+ # Extract all the information we need
176+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
177+ logprior = DynamicPPL. getlogprior (vi)
178+ loglikelihood = DynamicPPL. getloglikelihood (vi)
179+
180+ # Convert values to the format needed (with individual VarNames split up).
181+ # TODO (penelopeysm): This wouldn't be necessary if not for MCMCChains's poor
182+ # representation...
183+ iters = map (
184+ DynamicPPL. varname_and_value_leaves,
185+ keys (vals_as_in_model),
186+ values (vals_as_in_model),
187+ )
188+ values_split = mapreduce (collect, vcat, iters)
189+
190+ # Get additional statistics
191+ stats = getstats (sampler_transition)
192+ return new {typeof(values_split),typeof(logprior),typeof(stats)} (
193+ values_split, logprior, loglikelihood, stats
194+ )
195+ end
196+ function Transition (
197+ model:: DynamicPPL.Model ,
198+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
199+ parameters:: AbstractVector ,
200+ sampler_transition,
201+ )
202+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
203+ # much faster to convert it to a typed varinfo first, hence this method.
204+ # https://github.com/TuringLang/Turing.jl/issues/2604
205+ return Transition (
206+ model, DynamicPPL. typed_varinfo (untyped_vi), parameters, sampler_transition
207+ )
208+ end
144209end
145210
146- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147211function metadata (t:: Transition )
148- stat = t. stat
149- if stat === nothing
150- return (lp= t. lp,)
151- else
152- return merge ((lp= t. lp,), stat)
153- end
212+ return merge (
213+ t. stat,
214+ (
215+ lp= t. logprior + t. loglikelihood,
216+ logprior= t. logprior,
217+ loglikelihood= t. loglikelihood,
218+ ),
219+ )
220+ end
221+ function metadata (vi:: AbstractVarInfo )
222+ return (
223+ lp= DynamicPPL. getlogjoint (vi),
224+ logprior= DynamicPPL. getlogp (vi),
225+ loglikelihood= DynamicPPL. getloglikelihood (vi),
226+ )
154227end
155-
156- # TODO (DPPL0.37/penelopeysm): Fix this
157- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158-
159- # Metadata of VarInfo object
160- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
162228
163229# #########################
164230# Chain making utilities #
165231# #########################
166232
167- """
168- getparams(model, t)
169-
170- Return a named tuple of parameters.
171- """
172- getparams (model, t) = t. θ
173- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176- # of the parameters change depending on the realizations. Hence we have to use
177- # `values_as_in_model`, which re-runs the model and extracts the parameters
178- # as they are seen in the model, i.e. in the constrained space. Moreover,
179- # this means that the code below will work both of linked and invlinked `vi`.
180- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182- vals = DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
183-
184- # Obtain an iterator over the flattened parameter names and values.
185- iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
186-
187- # Materialize the iterators and concatenate.
188- return mapreduce (collect, vcat, iters)
233+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
234+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
235+ t = Transition (model, vi, vi[:], nothing )
236+ return getparams (model, t)
189237end
190- function getparams (
191- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
192- )
193- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
194- # much faster to convert it to a typed varinfo before calling getparams.
195- # https://github.com/TuringLang/Turing.jl/issues/2604
196- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
197- end
198- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
199- return float (Real)[]
200- end
201-
202238function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
203239 names_set = OrderedSet {VarName} ()
204240 # Extract the parameter names and values from each transition.
205241 dicts = map (ts) do t
242+ # TODO (penelopeysm): Get rid of AbstractVarInfo transitions. see
243+ # https://github.com/TuringLang/Turing.jl/issues/2631. That would
244+ # allow us to just use t.θ here.
206245 nms_and_vs = getparams (model, t)
207246 nms = map (first, nms_and_vs)
208247 vs = map (last, nms_and_vs)
@@ -221,7 +260,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
221260end
222261
223262function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
224- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
263+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
225264 return [:lp ], valmat
226265end
227266
0 commit comments