-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test model logp before starting any MCMC chains #4211
Conversation
I'm sorry, but can you do the PR without applying |
I think the opposite has happened - + from .util import (chains_and_samples, dataset_to_point_dict, get_default_varnames, get_untransformed_name,
+ is_transformed_name, update_start_vals) @StephenHogg please see the Python Style guide for this repo |
Sorry about this - had auto-linting on in my GUI and didn't realise. Have a look now, hopefully it's clearer. |
pymc3/sampling.py
Outdated
for chain_start_vals in start: | ||
update_start_vals(chain_start_vals, model.test_point, model) | ||
|
||
start_points = [start] if isinstance(start, dict) else start |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I remember correctly, the downstream code treats a list
of start
as "start points for each chain", which could explain your index error.
yeah - the start points object doesn't get used further down, it's just for
checking the initial conditions in a way that makes handling both the case
when it's a dictionary and the case when it's an array easy. `Start` is
still what gets fed into functions later on, hence my confusion haha
…On Tue, Nov 10, 2020, 19:12 Michael Osthege ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In pymc3/sampling.py
<#4211 (comment)>:
> @@ -419,6 +419,29 @@ def sample(
"""
model = modelcontext(model)
+ if start is None:
+ start = model.test_point
+ else:
+ if isinstance(start, dict):
+ update_start_vals(start, model.test_point, model)
+ else:
+ for chain_start_vals in start:
+ update_start_vals(chain_start_vals, model.test_point, model)
+
+ start_points = [start] if isinstance(start, dict) else start
If I remember correctly, the downstream code treats a list of start as
"start points for each chain", which could explain your index error.
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#4211 (review)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/ADFQOPVG3DUPEL5A43RSBIDSPDYYBANCNFSM4TPH3KGA>
.
|
The error you posted in #4116 could also be a cause of invalid model test points. Could be that not all distributions have tests points. |
Just for clarity - what's the path forward here? Sorry for the bother |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think running this check is generally a good idea, but I think it needs to be put in at the right place. If you look at sampling.py
, this change will get rid of a bunch of downstream code (there are some if start is None
, but I don't think in a great way.
What if we had a _check_start_point
function, and it got called at the end of init_nuts
? I think it would contain model.check_test_point
, and the nice error messages, but it would not do anything to the start
argument passed to it.
Sorry, I just read the attached issue, and it seems like that was steering @StephenHogg to put the changes where they are. Interested to hear if what you and @michaelosthege think! |
As a first time contributor I defer to Michael! :) |
I think Colin is right: The block could easily become its own function. That also makes it easier to test, or improve. |
Ok - are you also saying the new function should be called at the end of |
No, the NUTS initialization often suffers from |
The "main path" logic right now is:
I think Michael's right that Concretely, I'm suggesting
|
@michaelosthege any more thoughts on the above? Would like to make sure I'm clear about what I'm coding up before starting again |
@StephenHogg listen to Colin on this one. He's much more literate in what the NUTS code is actually doing. With those checks in their own function, you can run them before & after NUTS initialization. |
I've shifted this into a function called |
Looking at the test output, it seems like a few other tests (e.g. |
Here's the output I get from pytest at this point, if that helps. Some of these are a bit mystifying, as I'm not sure why I'd be getting a max recursion depth error on a test that I've not touched, for instance. Will push one more change to format the error string a bit more nicely, but after that I think I'm probably stuck for now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this looks nice! I took a look at most of the the test failures, and they're surprisingly helpful. Feel free to ping again if you need more help, but I think this is close:
- Delete
test_hmc.py:test_nuts_error_reporting
. your check is a better one for the same behavior. test_sampling.py:test_deterministic_of_observed
looks like a flake. let's ignore that and hope it goes away. if it doesn't, make thertol
bigger.test_examples.py::TestLatentOccupancy::test_run
is interesting, and looks like a legit failure you found! In this case, the likelihood is passing parameters in the wrong order. It should be
pm.ZeroInflatedPoisson("y", psi, theta, observed=y)
(note thatpsi
andtheta
are switched). I imagine it was passing because the multipart sampling got everything to a reasonable place.- Two failures in
pymc3/tests/test_step.py
can also be either deleted, or ported to the new exception you throw -- it looks like we have aSamplingError
defined, which may be a good, specific error to raise instead of aValueError
.
The only thing still failing at this point is one test in Edit: the flaky test is also not passing, but that definitely passes locally |
Hi @ColCarroll - the only thing that still fails now is |
This looks great! What if you loosen the tolerances on the test, but also open a bug and mention that it got worse when this PR was merged? That's very strange... I think the last two things are:
|
Codecov Report
@@ Coverage Diff @@
## master #4211 +/- ##
==========================================
- Coverage 88.14% 87.95% -0.19%
==========================================
Files 87 87
Lines 14243 14248 +5
==========================================
- Hits 12554 12532 -22
- Misses 1689 1716 +27
|
Yes, that's what I'm saying - I can either leave the conflict in, in which case I can't merge, or I can resolve the conflict in which case linting fails because there's an unneeded import. It's a Catch-22. |
Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
Shouldn't be a catch-22 😄 Can you try
Then, in
Change it to
(i.e., choose the current changes, ignore the incoming ones) Then,
for more on |
As before - there's a mysterious new test failure |
Wow, CI got changed under you!
|
This new error doesn't seem to have much to do with the code I wrote? Not sure, though |
can you check if that test passes when you run it locally?
|
@MarcoGorelli passes locally, had to update
|
Wait, all checks have passed now? Maybe the test was flaky? |
Yes, I think so. Thanks @StephenHogg! |
Whew, thanks |
+1 Thanks for sticking with us, @StephenHogg -- this was trickier than expected, but I think it will really improve lots of people's experiences. |
* - Fix regression caused by #4211 * - Add test to make sure jitter is being applied to chains starting points by default * - Import appropriate empty context for python < 3.7 * - Apply black formatting * - Change the second check_start_vals to explicitly run on the newly assigned start variable. * - Improve test documentation and add a new condition * Use monkeypatch for more robust test * - Black formatting, once again...
This PR addresses #4116 - making
find_MAP
andsample
check their starting conditions before running any chains. I probably need to work out what the linting settings this repo uses are because it seems like a fair bit of formatting has changed.