Request: allow pymc.sampling_jax to keep tuning samples #6723
leialorenzo started this conversation in Ideas
Replies: 0 comments
Hi!
I just attended the PyMC office hours and was encouraged to open an issue (and there I saw that I should put it up for discussion first).
In some cases (for example, while debugging and checking where a model runs into problems) I want to inspect the tuning draws from each chain.
`pm.sample()` has the `discard_tuned_samples` argument (set to `False` to keep the tuning draws), but I was unable to find an equivalent when using `pm.sampling_jax.sample_numpyro_nuts`. I tried passing it through the kwargs: I got no error and no warning, but no tuning draws were kept either.
I believe Christian Luhmann mentioned that this can be worked around in NumPyro by using its `warmup` and `run` methods iteratively, but maybe it could be implemented within `pymc.sampling_jax`?

Note: this is the first time I have started a discussion of any sort, so do let me know if anything is wrong or missing!
EDIT: I just saw that other people suggest ideas here. I think the simplest way to implement it from the user's side is to make the same argument available for the JAX sampler,
`pm.sampling_jax.sample_numpyro_nuts(..., discard_tuned_samples=bool)`
keeping the same default value of `True`.