@@ -124,75 +124,94 @@ end
124124# #####################
125125# Default Transition #
126126# #####################
127- # Default
128- getstats (t) = nothing
127+ getstats (:: Any ) = NamedTuple ()
129128
129+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
130+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
130131abstract type AbstractTransition end
131132
132- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
133+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
133134 θ:: T
134- lp:: F # TODO : merge `lp` with `stat`
135- stat:: S
136- end
135+ logprior:: F
136+ loglikelihood:: F
137+ stat:: N
138+
139+ """
140+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
141+
142+ Construct a new `Turing.Inference.Transition` object using the outputs of a
143+ sampler step.
144+
145+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
146+ already been set_. However, the accumulators (e.g. logp) may in general
147+ have junk contents. The role of this method is to re-evaluate `model` and
148+ thus set the accumulators to the correct values.
149+
150+ `sampler_transition` is the transition object returned by the sampler
151+ itself and is only used to extract statistics of interest.
152+ """
153+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
154+ vi = DynamicPPL. setaccs!! (
155+ vi,
156+ (
157+ DynamicPPL. ValuesAsInModelAccumulator (true ),
158+ DynamicPPL. LogPriorAccumulator (),
159+ DynamicPPL. LogLikelihoodAccumulator (),
160+ ),
161+ )
162+ _, vi = DynamicPPL. evaluate!! (model, vi)
163+
164+ # Extract all the information we need
165+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
166+ logprior = DynamicPPL. getlogprior (vi)
167+ loglikelihood = DynamicPPL. getloglikelihood (vi)
168+
169+ # Get additional statistics
170+ stats = getstats (sampler_transition)
171+ return new {typeof(vals_as_in_model),typeof(logprior),typeof(stats)} (
172+ vals_as_in_model, logprior, loglikelihood, stats
173+ )
174+ end
137175
138- Transition (θ, lp) = Transition (θ, lp, nothing )
139- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
140- # TODO (DPPL0.37/penelopeysm): Fix this
141- θ = getparams (model, vi)
142- lp = getlogjoint_internal (vi)
143- return Transition (θ, lp, getstats (t))
176+ function Transition (
177+ model:: DynamicPPL.Model ,
178+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
179+ sampler_transition,
180+ )
181+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
182+ # much faster to convert it to a typed varinfo first, hence this method.
183+ # https://github.com/TuringLang/Turing.jl/issues/2604
184+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
185+ end
144186end
145187
146- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
147188function metadata (t:: Transition )
148- stat = t. stat
149- if stat === nothing
150- return (lp= t. lp,)
151- else
152- return merge ((lp= t. lp,), stat)
153- end
189+ return merge (
190+ t. stat,
191+ (
192+ lp= t. logprior + t. loglikelihood,
193+ logprior= t. logprior,
194+ loglikelihood= t. loglikelihood,
195+ ),
196+ )
197+ end
198+ function metadata (vi:: AbstractVarInfo )
199+ return (
200+ lp= DynamicPPL. getlogjoint (vi),
201+ logprior= DynamicPPL. getlogp (vi),
202+ loglikelihood= DynamicPPL. getloglikelihood (vi),
203+ )
154204end
155-
156- # TODO (DPPL0.37/penelopeysm): Fix this
157- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
158-
159- # Metadata of VarInfo object
160- # TODO (DPPL0.37/penelopeysm): Add log-prior and log-likelihood terms as well
161- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
162205
163206# #########################
164207# Chain making utilities #
165208# #########################
166209
167- """
168- getparams(model, t)
169-
170- Return a named tuple of parameters.
171- """
172- getparams (model, t) = t. θ
173- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
174- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
175- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
176- # of the parameters change depending on the realizations. Hence we have to use
177- # `values_as_in_model`, which re-runs the model and extracts the parameters
178- # as they are seen in the model, i.e. in the constrained space. Moreover,
179- # this means that the code below will work both of linked and invlinked `vi`.
180- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
181- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
182- return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
210+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
211+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
212+ t = Transition (model, vi, nothing )
213+ return getparams (model, t)
183214end
184- function getparams (
185- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
186- )
187- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
188- # much faster to convert it to a typed varinfo before calling getparams.
189- # https://github.com/TuringLang/Turing.jl/issues/2604
190- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
191- end
192- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
193- return Dict {VarName,Any} ()
194- end
195-
196215function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
197216 names_set = OrderedSet {VarName} ()
198217 # Extract the parameter names and values from each transition.
@@ -208,7 +227,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
208227 iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
209228 mapreduce (collect, vcat, iters)
210229 end
211-
212230 nms = map (first, nms_and_vs)
213231 vs = map (last, nms_and_vs)
214232 for nm in nms
@@ -224,7 +242,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
224242end
225243
226244function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
227- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
245+ valmat = reshape ([DynamicPPL . getlogjoint (t) for t in ts], :, 1 )
228246 return [:lp ], valmat
229247end
230248
@@ -466,16 +484,17 @@ function transitions_from_chain(
466484 chain:: MCMCChains.Chains ;
467485 sampler= DynamicPPL. SampleFromPrior (),
468486)
469- vi = Turing . VarInfo (model)
487+ vi = VarInfo (model)
470488
471489 iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
472490 transitions = map (iters) do (sample_idx, chain_idx)
473491 # Set variables present in `chain` and mark those NOT present in chain to be resampled.
492+ # TODO (DPPL0.37/penelopeysm): Aargh! setval_and_resample!!!! Burn this!!!
474493 DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
475494 model (rng, vi, sampler)
476495
477496 # Convert `VarInfo` into `NamedTuple` and save.
478- Transition (model, vi)
497+ Transition (model, vi, nothing )
479498 end
480499
481500 return transitions
0 commit comments