Implement ParamsWithStats(::FastLDF, ::AbstractVector{<:Real}) #1119

@penelopeysm

ParamsWithStats is a classic case where we want to re-evaluate the model for correctness, but we don't actually care about the output VarInfo beyond its accumulators.

Implementing this method, rather than calling unflatten on a full VarInfo and then re-evaluating, would make things faster: the current approach involves a deepcopy, which is quite costly (https://github.com/TuringLang/Turing.jl/blob/905f7fa79dbcd64497711813c11d4c08136bff33/src/mcmc/Inference.jl#L162-L168).

For example:

using Turing, ProfileView
using LinearAlgebra  # provides `I`, used in the MvNormal covariance below

J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools_centered(J, y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, J), tau^2 * I)
    for i in 1:J
        y[i] ~ Normal(theta[i], sigma[i])
    end
end
model_esc = eight_schools_centered(J, y, sigma)

# set chain_type=Any to avoid MCMCChains / FlexiChains bundle_samples being included.
@profview sample(model_esc, NUTS(; adtype=AutoForwardDiff()), 10000; chain_type=Any, progress=false, verbose=false)
[Image: ProfileView flame graph of the sample call above]

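The red block on the right-hand side is precisely deepcopy. Note that this is with NUTS, which does loads of model evaluations per step, so the one deepcopy per sample is spread over a lot of other work. With something like MH, which does far fewer evaluations per step, the relative cost is bound to be even more significant.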

(To be completely transparent, this profile was obtained using FastLDF #1113, so the relative amount of time spent in deepcopy is larger than it would be if you were to run this on current main.)
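
To illustrate the general idea in isolation, here is a rough sketch using only Base Julia. Everything in it is a hypothetical stand-in: `ToyState`, `evaluate!`, `stats_with_deepcopy` and `stats_accumulators_only` do not exist in DynamicPPL or Turing, and the real VarInfo and accumulator types are considerably more involved. It only aims to show why copying a large state object for every sample costs more than building a fresh, small accumulator and re-evaluating.

```julia
# Toy stand-ins only: none of these types or functions exist in DynamicPPL or Turing.
struct ToyState
    metadata::Dict{Symbol,Vector{Float64}}  # stands in for the heavyweight part of a VarInfo
    logp::Base.RefValue{Float64}            # stands in for an accumulator
end

# Placeholder "model evaluation": writes a log-density-like value into the accumulator.
function evaluate!(state::ToyState, params::AbstractVector{<:Real})
    state.logp[] = -0.5 * sum(abs2, params)
    return state
end

# Strategy A: copy the whole state before every re-evaluation.
function stats_with_deepcopy(state::ToyState, params)
    fresh = deepcopy(state)
    evaluate!(fresh, params)
    return (; params=copy(params), logp=fresh.logp[])
end

# Strategy B: create only a fresh accumulator and share the metadata,
# which the placeholder evaluation never mutates.
function stats_accumulators_only(state::ToyState, params)
    fresh = ToyState(state.metadata, Ref(0.0))
    evaluate!(fresh, params)
    return (; params=copy(params), logp=fresh.logp[])
end

state = ToyState(Dict(Symbol("v", i) => zeros(100) for i in 1:50), Ref(0.0))
params = randn(10)

a = stats_with_deepcopy(state, params)
b = stats_accumulators_only(state, params)
@assert a.logp == b.logp  # same result, different cost

# Rough timing; absolute numbers will vary by machine.
t_copy = @elapsed for _ in 1:10_000; stats_with_deepcopy(state, params); end
t_acc = @elapsed for _ in 1:10_000; stats_accumulators_only(state, params); end
println("deepcopy: ", t_copy, " s; accumulators only: ", t_acc, " s")
```

The point of strategy B is that only the part of the state that changes per sample is rebuilt. Whether the real implementation can share the rest safely depends on details of DynamicPPL that this toy does not model.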
