Description
ParamsWithStats is a classic case where we want to re-evaluate the model for correctness, but we don't actually care about the output VarInfo beyond its accumulators.
Implementing this method, rather than calling unflatten on a full VarInfo and then re-evaluating, would make things faster: the current approach involves a deepcopy, which is quite costly (https://github.com/TuringLang/Turing.jl/blob/905f7fa79dbcd64497711813c11d4c08136bff33/src/mcmc/Inference.jl#L162-L168).
For example:
```julia
using Turing, ProfileView
using LinearAlgebra: I  # identity matrix used in the MvNormal covariance

# Eight schools data: number of schools, observed effects, and their standard errors.
J = 8
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]

# Centred parameterisation of the eight schools model.
@model function eight_schools_centered(J, y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, J), tau^2 * I)
    for i in 1:J
        y[i] ~ Normal(theta[i], sigma[i])
    end
end

model_esc = eight_schools_centered(J, y, sigma)

# Set chain_type=Any to keep MCMCChains / FlexiChains bundle_samples out of the profile.
@profview sample(model_esc, NUTS(; adtype=AutoForwardDiff()), 10000; chain_type=Any, progress=false, verbose=false)
```
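The red block on the right-hand side of the resulting flame graph is precisely deepcopy. Note that this is with NUTS, which does loads of model evaluations per step, so the cost of the deepcopy is spread over all of them; with something like MH, which evaluates the model far fewer times per step, it's bound to be even more significant.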
(To be completely transparent, this profile was obtained using FastLDF (#1113), so the relative amount of time spent in deepcopy is larger than it would be if you were to run this on current main.)
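To make the difference between the two approaches concrete, here is a minimal toy sketch. `ToyVarInfo`, `logdensity`, `stats_via_copy` and `stats_accs_only` are hypothetical stand-ins invented for illustration, not DynamicPPL's API, and the "model" is just an unnormalised standard-normal log density. The copy path mirrors the current behaviour (deepcopy, write the new parameters in, re-evaluate); the other path only fills a fresh accumulator and never touches the stored values.

```julia
# Hypothetical toy types and functions; NOT the DynamicPPL API.
mutable struct ToyVarInfo
    values::Dict{Symbol,Vector{Float64}}  # per-variable storage, costly to deepcopy when large
    logp::Float64                          # accumulator: the only part the stats need
end

# Stand-in for a model evaluation: unnormalised standard-normal log density.
logdensity(xs) = sum(x -> -0.5 * x^2, xs)

# Current-style path: copy everything, write the new parameters in, re-evaluate.
function stats_via_copy(vi::ToyVarInfo, θ::Vector{Float64})
    vi2 = deepcopy(vi)           # copies every stored vector, not just the accumulator
    vi2.values[:θ] = copy(θ)     # "unflatten" the new parameters into the copy
    vi2.logp = logdensity(vi2.values[:θ])
    return vi2.logp
end

# Accumulator-only path: leave `vi` untouched and accumulate into a fresh value.
function stats_accs_only(::ToyVarInfo, θ::Vector{Float64})
    logp = 0.0                   # fresh accumulator
    logp += logdensity(θ)
    return logp
end
```

Both paths return the same log density and neither mutates the original `ToyVarInfo`; measuring each with `@allocated` would show the copy path allocating the full `values` storage on every call, which is the per-step overhead this issue is about.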