Request: allow pymc.sampling_jax to keep tuning samples #6723
leialorenzo started this conversation in Ideas
Hi!
I just attended the PyMC office hours and was encouraged to open an issue (where I then saw that I should first put it up for discussion).
In some cases (while debugging and checking where a model has issues), I am interested in looking at the tuning draws from each chain.
`pm.sample()` has the `discard_tuned_samples=False` argument, but I was unable to find the equivalent when using `pm.sampling_jax.sample_numpyro_nuts`. I've tried passing it within the kwargs and got no error, but also no warning message, and no tuning draws were kept.
I believe Christian Luhmann mentioned how this could be worked around in numpyro by calling its `warmup` and `run` methods in turn, but maybe it could be implemented within `pymc.sampling_jax`?

Note: this is the first time I've started a discussion of any sort, so do let me know if anything is wrong or missing!
EDIT: I just saw that other people offer ideas here. I think the simplest way to implement it, from the user's side, would be to make the same argument available when using the JAX sampler, keeping the same default value of `True`:

`pm.sampling_jax.sample_numpyro_nuts(..., discard_tuned_samples=bool)`
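To make the proposed semantics concrete, the sketch below uses a hypothetical NumPy helper. `split_tuning_draws` is not a PyMC function. It assumes each chain's raw draws are stored with the `tune` tuning draws first, and it either drops them (the proposed default) or returns them alongside the posterior draws. This roughly mirrors how `pm.sample()` treats `discard_tuned_samples`.

```python
import numpy as np


def split_tuning_draws(raw_draws, tune, discard_tuned_samples=True):
    """Hypothetical helper: split per-chain draws into tuning and posterior parts.

    raw_draws: array of shape (chains, tune + draws, ...), tuning draws first.
    Returns (warmup, posterior); warmup is None when discard_tuned_samples=True.
    """
    raw_draws = np.asarray(raw_draws)
    if not 0 <= tune <= raw_draws.shape[1]:
        raise ValueError("tune must be between 0 and the number of draws per chain")
    warmup = raw_draws[:, :tune]      # first `tune` draws of every chain
    posterior = raw_draws[:, tune:]   # remaining draws of every chain
    if discard_tuned_samples:
        return None, posterior
    return warmup, posterior


# Usage: 2 chains, 2 tuning draws and 3 posterior draws each
raw = np.arange(10).reshape(2, 5)
warmup, posterior = split_tuning_draws(raw, tune=2, discard_tuned_samples=False)
print(warmup.shape, posterior.shape)
```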