Adding Deterministic for sampling_numpyro #5182
Conversation
Codecov Report

@@           Coverage Diff           @@
##             main    #5182   +/- ##
=====================================
  Coverage   78.09%   78.09%
=====================================
  Files          88       88
  Lines       14176    14172     -4
=====================================
- Hits        11071    11068     -3
+ Misses       3105     3104     -1
Force-pushed from 488a9c7 to a81c42f
Thanks @zaxtax! Yeah, the looping is pretty unfortunate and I think it will be too slow. Wasn't there a PR that already added that? I vaguely remember something about it.
Good call!
pymc/tests/test_sampling_jax.py (Outdated)

trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True)

assert 8 < trace.posterior["a"].mean() < 11
assert 4 < trace.posterior["b"].mean() < 6
This can be much more strict:

assert 4 < trace.posterior["b"].mean() < 6
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)
obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
    a = pm.Normal("a", 0, 1)
    b = pm.Deterministic("b", a / 2.0)
Let's (also) add a deterministic of a transformed variable? That's often where we find problems.
I believe a does get transformed here. But maybe we can add a check around that to this test.
By transform, I mean a variable like s = pm.Lognormal(), which gets a log transform automatically under the hood. I am not referring to the pm.Deterministic transform.
Sorry, I forgot that for this test I draw a from a Normal, not a Uniform, which does have a transform under the hood. I'll add something here.
No problem! The test looks great now.
I left two comments about the tests. I agree with @twiecki that we probably want to find a better solution, if there is one.
pymc/sampling_jax.py (Outdated)

@@ -116,6 +123,7 @@ def sample_numpyro_nuts(
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

logp_fn = get_jaxified_logp(model)
fn = model.fastfn(vars_to_sample)
Can you give a more informative name and/or a comment that explains what the purpose of this fn will be?
Force-pushed from cf92e4f to 5ead708
I think I've addressed all concerns.
Is there any way to make the code coverage failure go away? The lines it's complaining about are definitely run by the tests.
The results of
Thanks!
This addresses #5100 and passes tests. My only concern is that this is likely not the fastest way to get deterministic variables, since iterating per sample and per chain in pure Python takes away much of the speed of this method.
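The looping concern can be illustrated with a small, self-contained sketch. This is plain NumPy, not the PR's actual code: the array shapes and the b = a / 2 deterministic are illustrative assumptions, and the point is just that computing deterministics per draw in pure Python does redundant interpreter work that a single array operation avoids.

```python
import numpy as np

# Hypothetical posterior draws for a free variable "a": (chains, draws).
rng = np.random.default_rng(0)
a = rng.normal(10.0, 1.0, size=(2, 500))

# Per-sample loop, sketched in pure Python: one function call's worth of
# work per (chain, draw) pair.
b_loop = np.empty_like(a)
for chain in range(a.shape[0]):
    for draw in range(a.shape[1]):
        b_loop[chain, draw] = a[chain, draw] / 2.0

# Vectorized equivalent: one array operation over all chains and draws.
b_vec = a / 2.0

assert np.allclose(b_loop, b_vec)
```

A vectorized (or JAX vmap-style) evaluation over the whole posterior array is the kind of "better solution" the reviewers allude to, though how to do that for arbitrary model graphs is exactly what is being discussed.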