How to speed up model fitting of CustomDist? #6661
Hello, I'm trying to implement the skew Student-t distribution using

```python
def logp_skewt(value, nu, mu, sigma, alpha, *args, **kwargs):
    return (
        pm.math.log(2) +
        pm.logp(pm.StudentT.dist(nu, mu=mu, sigma=sigma), value) +
        pm.logcdf(pm.StudentT.dist(nu, mu=mu, sigma=sigma), alpha * value) -
        pm.math.log(sigma)
    )
```

I am able to sample from this distribution:

```python
with pm.Model():
    pm.CustomDist('target', 1, 0, 3, -10, logp=logp_skewt)
    model_trace = pm.sample(
        nuts_sampler="numpyro",
        draws=2_000,
        chains=1,
    )

samples = model_trace.posterior.target.to_numpy()
eps = 0.01
min_val, max_val = np.quantile(samples, [eps, 1 - eps])
valid_samples = samples[(samples >= min_val) & (samples <= max_val)]
```

However, when I try to re-fit the model, sampling becomes very slow:

```python
with pm.Model() as fitted_model:
    nu = pm.HalfCauchy('nu', beta=1)
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfCauchy('sigma', beta=1)
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    skewt = pm.CustomDist('likelihood', nu + eps, mu, sigma + eps, alpha,
                          logp=logp_skewt, observed=valid_samples[:1000])
    model_trace = pm.sample(
        nuts_sampler="pymc",
        draws=100,
        tune=100,
        chains=1,
    )
```

Sampling emits warnings, and it took about 16 minutes to finish fitting. Is there a common way to speed up the computation?
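(As a sanity check on the math rather than the speed, the same log-density can be reproduced outside PyMC with plain SciPy. `logp_skewt_np` below is my own name for this hypothetical helper; it mirrors `logp_skewt` term by term, including the extra `-log(sigma)` carried over from the original.)

```python
import numpy as np
from scipy import stats

def logp_skewt_np(value, nu, mu, sigma, alpha):
    """NumPy/SciPy mirror of logp_skewt, for quick sanity checks."""
    return (
        np.log(2)
        + stats.t.logpdf(value, df=nu, loc=mu, scale=sigma)
        + stats.t.logcdf(alpha * value, df=nu, loc=mu, scale=sigma)
        - np.log(sigma)
    )

# Evaluate at the parameter values used for 'target' above.
print(logp_skewt_np(0.5, nu=1, mu=0, sigma=3, alpha=-10))
```

One useful check: with `alpha = 0` the CDF term reduces to `log(0.5)`, so the skew factor cancels the `log(2)` and the density collapses back to a symmetric Student-t (up to the extra `-log(sigma)` term).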
We are working on speeding up these types of gradients in pymc-devs/pytensor#174. Right now they are implemented in NumPy and can't be compiled to JAX. I will try to push that PR over the finish line sometime in the next few weeks. For now, if you want to speed them up, you might need to re-implement the Ops manually in your target backend, which isn't trivial if you are not familiar with PyTensor and/or JAX.