Add HalfNormalRV
JAX implementation
#1362
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This for #1335. The distribution does not exist exactly in
jax.random
, but it can be easily transformed. I choseloc + jax.random.truncated_normal(rng_key, 0.0, jax.numpy.inf, size, dtype) * scale
, but you could equally chooseloc + jax.numpy.abs(jax.random.normal(rng_key, size, dtype)) * scale
. It's up to the maintainers to decide which they prefer. The tests work but I'm happy to implement more tests if desired.Might also need to consider key splitting order #1345.
Side note: I'm more familiar with the numpyro/tfp-style HalfNormal distributions, which have a single
scale
parameter are centred at zero.Here are a few important guidelines and requirements to check before your PR can be merged:
pre-commit
is installed and set up.