How to deal with a stochastic logdensity_fn? #491

Answered by rlouf
oarriaga asked this question in Q&A

Have you tried using the lower-level kernel (e.g. NUTS) and, at each step, doing something like:

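import functools as ft
import jax
# Split the key: one half fixes the density's randomness, the other drives the transition
density_key, sample_key = jax.random.split(rng_key)
logdensity_fn = ft.partial(stochastic_logdensity, density_key)
new_state, info = kernel(sample_key, state, logdensity_fn, step_size, inverse_mass_matrix)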

You will need to do that for adaptation as well, and can take inspiration from the window adaptation's implementation. This is slightly more cumbersome than the high-level interface but should work fine.
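
To make the per-step pattern concrete, here is a minimal self-contained sketch of the sampling loop in plain JAX. It is an illustration under assumptions, not BlackJAX code: the body of `stochastic_logdensity` is a made-up noisy Gaussian, and `rw_kernel` and `run_chain` are hypothetical names, with `rw_kernel` being a simple random-walk Metropolis step standing in for the NUTS kernel. The only part taken from the answer above is the key split and the `ft.partial` binding at every step.

```python
import functools as ft

import jax
import jax.numpy as jnp


def stochastic_logdensity(rng_key, position):
    """Hypothetical noisy log-density: a unit Gaussian with a randomly shifted mean."""
    shift = 0.1 * jax.random.normal(rng_key, position.shape)
    return -0.5 * jnp.sum((position - shift) ** 2)


def rw_kernel(rng_key, position, logdensity_fn, step_size):
    """Stand-in random-walk Metropolis step (an assumption, not the NUTS kernel)."""
    proposal_key, accept_key = jax.random.split(rng_key)
    proposal = position + step_size * jax.random.normal(proposal_key, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    return jnp.where(accept, proposal, position), accept


def run_chain(rng_key, initial_position, num_steps, step_size=0.5):
    def one_step(position, step_key):
        # Fresh randomness for the density and for the transition at every step.
        density_key, sample_key = jax.random.split(step_key)
        logdensity_fn = ft.partial(stochastic_logdensity, density_key)
        new_position, accepted = rw_kernel(
            sample_key, position, logdensity_fn, step_size
        )
        return new_position, (new_position, accepted)

    step_keys = jax.random.split(rng_key, num_steps)
    _, (positions, accepted) = jax.lax.scan(one_step, initial_position, step_keys)
    return positions, accepted


positions, accepted = run_chain(jax.random.PRNGKey(0), jnp.zeros(2), num_steps=1000)
```

Within a single step the density is deterministic, because `density_key` is bound once and reused for both evaluations; the randomness only changes between steps. Swapping `rw_kernel` for a real kernel should only require carrying that kernel's state through the `scan` instead of a bare position.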

Answer selected by oarriaga