Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Deterministic for sampling_numpyro #5182

Merged
merged 6 commits into from
Nov 16, 2021

Conversation

zaxtax
Copy link
Contributor

@zaxtax zaxtax commented Nov 14, 2021

This addresses #5100 and passes tests. My only concern is that this is likely not the fastest way to get deterministic variables as iterating per sample and per chain in pure python takes away much of the speed of this method.

@codecov
Copy link

codecov bot commented Nov 14, 2021

Codecov Report

Merging #5182 (5ead708) into main (b9b9efc) will increase coverage by 0.00%.
The diff coverage is 0.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #5182   +/-   ##
=======================================
  Coverage   78.09%   78.09%           
=======================================
  Files          88       88           
  Lines       14176    14172    -4     
=======================================
- Hits        11071    11068    -3     
+ Misses       3105     3104    -1     
Impacted Files Coverage Δ
pymc/sampling_jax.py 0.00% <0.00%> (ø)
pymc/backends/report.py 89.51% <0.00%> (-2.10%) ⬇️

@twiecki
Copy link
Member

twiecki commented Nov 15, 2021

Thanks @zaxtax! Yeah, the looping is pretty unfortunate and I think will be too slow. Wasn't there a PR that already added that? I vaguely remember something about it.

@twiecki
Copy link
Member

twiecki commented Nov 15, 2021

#4427

@zaxtax
Copy link
Contributor Author

zaxtax commented Nov 15, 2021

Good call!

trace = sample_numpyro_nuts(chains=2, random_seed=1322, keep_untransformed=True)

assert 8 < trace.posterior["a"].mean() < 11
assert 4 < trace.posterior["b"].mean() < 6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be a much more strict:

Suggested change
assert 4 < trace.posterior["b"].mean() < 6
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].value / 2)

obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
a = pm.Normal("a", 0, 1)
b = pm.Deterministic("b", a / 2.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's (also) add a deterministic of a transformed variable? That's often where we find problems.

Copy link
Contributor Author

@zaxtax zaxtax Nov 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe a does get transformed here. But maybe we can add a check around that to this test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By transform, I mean a variable like s = pm.Lognormal(), which has a logTransform automatically under the hood. I am not referring to the pm.Deterministic transform

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I forgot that for this test I draw a from a Normal and not a Uniform which has a transform under the hood. I'll add something here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem! The test looks great now

@ricardoV94
Copy link
Member

I left two comments about the tests. I agree with @twiecki that we probably want to find a better solution, if there is one.

@@ -116,6 +123,7 @@ def sample_numpyro_nuts(
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

logp_fn = get_jaxified_logp(model)
fn = model.fastfn(vars_to_sample)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you give a more informative name and/or a comment that explains what the purpose of this fn will be?

@zaxtax
Copy link
Contributor Author

zaxtax commented Nov 15, 2021

I think I've addressed all concerns.

@zaxtax
Copy link
Contributor Author

zaxtax commented Nov 16, 2021

Is there any way to make the code coverage failure go away? The lines it's complaining out are definitely run by the tests.

@ricardoV94
Copy link
Member

The results of test_sampling_jax are not added to the coverage report. You can ignore the "failure"

@twiecki twiecki merged commit 44cf8a7 into pymc-devs:main Nov 16, 2021
@twiecki
Copy link
Member

twiecki commented Nov 16, 2021

Thanks!

@zaxtax zaxtax deleted the add_numpyro_deterministic branch November 16, 2021 14:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants