Skip to content

Adding StatsBase.predict to the API #466

Closed
@sethaxen

Description

@sethaxen

In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction, conditioning the model on each draw in the chains and calls rand on the model. We also want to do the same thing for InferenceData (see #465).

It would be convenient if StatsBase.predict was added to the DynamicPPL API. It's already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), its default implementation could be to just call rand for a conditioned model:

StatsBase.predict(rng::AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = predict(Random.default_rng(), model, x)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions