
Slice stalls during sampling for some hierarchical models #5891

Closed
covertg opened this issue Jun 13, 2022 · 7 comments
@covertg

covertg commented Jun 13, 2022

When trying to use the Slice sampler on certain models, it never finishes sampling. Sampling starts off quickly, but it gradually gets slower and ultimately seems to grind to a halt.

This first happened on a weird model with a non-differentiable likelihood, but while debugging I found that it occurs in simpler models as well. I haven't been able to pin down exactly what makes a model freeze, but I've noticed it happening in hierarchical models where the variance is also estimated. For example, a model very similar to the one in the "using shared variables" notebook, but with a global sigma parameter added, freezes up when slice sampling. A naive 8-schools model, however, does not seem to slow down.

Example model that freezes:

# generate df_data as in the example notebook
coords = {"date": df_data.index, "city": df_data.columns}
with pm.Model(coords=coords) as model:
    europe_mean = pm.Normal("europe_mean_temp", mu=15.0, sigma=3.0)
    europe_sigma = pm.HalfNormal("europe_sigma", 2.0)
    city_offset = pm.Normal("city_offset", mu=0.0, sigma=europe_sigma, dims="city")
    city_temperature = pm.Deterministic("city_temperature", europe_mean + city_offset, dims="city")

    data = pm.ConstantData("data", df_data, dims=("date", "city"))
    pm.Normal("likelihood", mu=city_temperature, sigma=0.5, observed=data)

It stalls out on my system and times out on a Google Colab instance. Here's the full notebook in Colab with the data, the sampling code, and the simple 8-schools model.

This is all the investigation I've been able to do so far, but let me know if there's anything more that'd be helpful for me to do.

  • PyMC/PyMC3 Version: 4.0.0
  • Aesara/Theano Version: 2.6.6
  • Python Version: 3.10, 3.7 (in Colab)
  • Operating system: Linux
  • How did you install PyMC/PyMC3: pip
@ricardoV94
Member

I have a suspicion that the tuning of Slice might be off / dumb. There were some recent changes to the sampler introduced in #5816, so it's also possible that a bug was introduced.

This PR adds some basic sample stats to see how many logp evaluations the step sampler is doing per draw: #5889

@ricardoV94
Member

ricardoV94 commented Jun 13, 2022

Concretely, I am a bit suspicious of this change to tuning introduced a long time ago: b988ba9

It does not seem to be doing the same thing as before at all.

@ricardoV94
Member

ricardoV94 commented Jun 13, 2022

No, the tuning seems alright, but it can sometimes be painfully slow. You may need to tweak the initial w of the Slice sampler manually to get reasonable speed from the beginning.
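To see why the initial w matters so much, here is a minimal sketch of a textbook univariate slice sampler (stepping out + shrinkage) — this is NOT PyMC's implementation, just an illustration of how a w far below the posterior scale inflates the number of logp evaluations per draw:

```python
import random

def slice_draw(logp, x0, w, rng):
    """One draw of a textbook univariate slice sampler (stepping out +
    shrinkage); returns (new_x, number_of_logp_evaluations)."""
    evals = 0

    def f(x):
        nonlocal evals
        evals += 1
        return logp(x)

    # Slice level: log(u) + logp(x0), with u ~ Uniform(0, 1).
    log_y = f(x0) - rng.expovariate(1.0)

    # Stepping out: place an interval of width w around x0 and grow it
    # until both ends fall outside the slice.
    left = x0 - w * rng.random()
    right = left + w
    while f(left) > log_y:
        left -= w
    while f(right) > log_y:
        right += w

    # Shrinkage: sample uniformly, shrinking the bracket toward x0 on
    # each miss. This always terminates, so no draw is ever "rejected".
    while True:
        x1 = left + (right - left) * rng.random()
        if f(x1) > log_y:
            return x1, evals
        if x1 < x0:
            left = x1
        else:
            right = x1

def mean_evals_per_draw(w, n_draws=200, seed=0):
    """Average logp evaluations per draw on a standard normal target."""
    rng = random.Random(seed)
    logp = lambda x: -0.5 * x * x  # N(0, 1) up to a constant
    x, total = 0.0, 0
    for _ in range(n_draws):
        x, e = slice_draw(logp, x, w, rng)
        total += e
    return total / n_draws

# A w near the posterior scale needs only a handful of evaluations per
# draw; a w that is far too small pays roughly (slice width)/w extra
# evaluations in the stepping-out loop on every single draw.
print(mean_evals_per_draw(w=1.0))
print(mean_evals_per_draw(w=0.001))
```

With w well below the target's scale, the stepping-out loop dominates the cost, which is consistent with sampling that crawls when the width is mistuned.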

@ricardoV94
Member

Oh, you say you observe a slowdown; I would expect it to stabilize over time, not get slower... Might need to look at the trace plots to see if it is getting stuck somewhere funny.

@covertg
Author

covertg commented Jun 13, 2022

Right, if the issue were that tuning is off at first, then I would expect it to speed up as it steps, rather than the opposite...

Given that these are hierarchical models with estimated variance, it's very possible that the posterior curvature is making trouble for the slice sampler. Would we expect it to essentially freeze up in such cases, though?

I can revisit this and make some trace plots later today. Part of the issue, though, is that for a moderate number of samples (e.g. 500 tune, 1000 draws) it consistently stalls before sampling ends. For what it's worth, in initial tests NUTS had no convergence issues. Also, this does plausibly seem related to the proposal method, as in a few tries I couldn't find a (reasonable) finite value for the iter_limit parameter large enough to avoid a runtime error.

(p.s. Sorry, thumb slipped & did not mean to close.)

@covertg covertg closed this as completed Jun 13, 2022
@covertg covertg reopened this Jun 13, 2022
@ricardoV94
Member

ricardoV94 commented Jun 13, 2022

The slice sampler never rejects, so it all depends on how calibrated the proposal w is to the posterior curvature. My best guess is that the Slice sampler is starting at a point far from the representative set and getting tuned to it, and then later finds itself in the right part of the posterior but with the wrong tuning width.

You can try to start the sampler at the mean obtained from NUTS, to see if it fares better, and check what happens to the tuned w in that case vs the slow cases you described.
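As an aside on why a mistuned w hurts far more in one direction than the other (this concerns the generic stepping-out/shrinkage scheme, not PyMC's code specifically): a w far smaller than the local slice width costs roughly (slice width)/w logp evaluations in the stepping-out loop, i.e. linear, while a w far larger than the slice costs only about log2(w / slice width) shrinkage steps, since each miss shrinks the bracket. A small simulation of the shrinkage phase alone, assuming the current point sits at the center of a slice of width s inside a bracket of width w:

```python
import random

def shrinkage_steps(s, w, trials=2000, seed=1):
    """Average number of shrinkage iterations needed to land inside a
    centered slice of width s, starting from a bracket of width w
    (w >= s). The current point x0 = 0 lies inside the slice."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        left, right = -w / 2.0, w / 2.0
        steps = 0
        while True:
            x = left + (right - left) * rng.random()
            steps += 1
            if abs(x) < s / 2.0:
                break
            # Shrink the bracket toward x0 = 0, as in the shrinkage phase.
            if x < 0:
                left = x
            else:
                right = x
        total += steps
    return total / trials

# Cost grows only logarithmically in w / s: even a bracket a million
# times wider than the slice needs a few dozen iterations, not millions.
print(shrinkage_steps(s=1e-3, w=1.0))
print(shrinkage_steps(s=1e-6, w=1.0))
```

So a tuned w that ends up too large is cheap to recover from, while one that ends up too small is not, which fits the "wrong tuning width after a bad start" hypothesis.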

@covertg
Author

covertg commented Jun 21, 2022

@ricardoV94, thanks for the help, and apologies for the slow response.

Indeed, I think my bug report was premature. After some more experimentation, I've realized the problem was a badly specified model rather than the slice sampler acting unexpectedly.

In both my original case and the simple test case I posted, the sampler was struggling in the region of the posterior where the estimated variance parameter was very close to zero. In the test case this came down to the structure of the graphical model: I had built a model that was technically estimating the between-city variance (see above). Moving the variance parameter so that it estimated the between-observation variance instead let the slice sampler run fine:

with pm.Model(coords=coords) as model:
    europe_mean = pm.Normal("europe_mean_temp", mu=15.0, sigma=3.0)
    europe_sigma = pm.HalfNormal("europe_sigma", sigma=5.0)
    city_temperature = pm.Normal("city_temperature", mu=europe_mean, sigma=1.0, dims="city")

    data = pm.ConstantData("data", df_data, dims=("date", "city"))
    pm.Normal("likelihood", mu=city_temperature, sigma=europe_sigma, observed=data)

My interpretation is that there isn't enough data to support a meaningfully nonzero variance estimate in the first model, which led the sampler to get waylaid in the region of the posterior where sigma ≈ 0. Hopefully that logic seems plausible!
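To put a rough number on that intuition (an illustrative calculation, not taken from the notebook): in the first model, city_offset | europe_sigma ~ Normal(0, europe_sigma), so the width of a typical slice for an offset scales linearly with europe_sigma. Using the full width at half maximum of a Normal(0, sigma) density, 2·sigma·sqrt(2·ln 2) ≈ 2.355·sigma, as a proxy for that slice width, a chain visiting sigma values spanning a few orders of magnitude (the values below are hypothetical) faces offset slices no single tuned w can fit:

```python
import math

def fwhm(sigma):
    """Full width at half maximum of a Normal(0, sigma) density; a rough
    proxy for the slice width the sampler must bracket for an offset."""
    return 2.0 * sigma * math.sqrt(2.0 * math.log(2.0))

# Hypothetical sigma values, chosen only to show the spread of scales a
# hierarchical posterior with mass near sigma ~= 0 can visit.
for sigma in (2.0, 0.5, 0.05, 0.005):
    print(f"sigma={sigma:>6}: offset slice width ~ {fwhm(sigma):.4f}")
```

The slice widths span the same three orders of magnitude as sigma itself, so whichever w the tuning settles on is badly wrong somewhere in the posterior.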

In my more complex case, I think it's a similar story with the data I've been using, but there I don't have an alternative way to estimate the variance of the parameter. There's more to figure out, but for now I've found that a prior on the variance that is bounded away from zero seems to give the slice sampler an easier time (e.g. TruncatedNormal(mu=0.01, lower=0.01)). (I understand this can be a fraught modeling choice, however... [1] [2]) Setting a more informative prior on the parameter itself also helped mitigate the stalls, but not every time.

Closing the issue, unless any of these conclusions seem off. Thanks again!

@covertg covertg closed this as completed Jun 21, 2022