Closed
Description
In Turing, `StatsBase.predict` is overloaded to dispatch on `DynamicPPL.Model` and `MCMCChains.Chains` (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls `rand` on the conditioned model. We want to do the same thing for `InferenceData` (see #465).
It would be convenient if `StatsBase.predict` were added to the DynamicPPL API. StatsBase is already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), its default implementation could simply call `rand` on a conditioned model:
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = StatsBase.predict(Random.default_rng(), model, x)
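
For illustration, here is a rough sketch of how the proposed default could be used for batch prediction over a set of draws. The toy model `demo`, the variable names `μ` and `y`, and the hand-written `draws` vector are all hypothetical and not part of any existing API. The two `StatsBase.predict` methods are defined locally only so the sketch is self-contained; this assumes Turing is not loaded in the same session.

```julia
using DynamicPPL, Distributions, Random, StatsBase

# Proposed default from this issue, defined locally for illustration only.
# (Defining it outside DynamicPPL is type piracy; a real implementation
# would live in DynamicPPL itself.)
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) =
    StatsBase.predict(Random.default_rng(), model, x)

# Hypothetical toy model: `μ` is the parameter, `y` the quantity to predict.
@model function demo()
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# Hypothetical posterior draws, each a NamedTuple of parameter values.
draws = [(μ = 0.1,), (μ = -0.3,), (μ = 0.7,)]

rng = Random.MersenneTwister(42)
model = demo()

# Batch prediction: condition on each draw, then sample the remaining variables.
preds = [StatsBase.predict(rng, model, draw) for draw in draws]
```

Each element of `preds` should be a `NamedTuple` holding the variables that were not conditioned on (here `y`). A method for `MCMCChains.Chains` or `InferenceData` would presumably do the same loop over the draws stored in the container.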