Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HalfNormalRV JAX implementation #1362

Merged
merged 1 commit into from
Dec 14, 2022

Conversation

theorashid
Copy link
Contributor

@theorashid theorashid commented Dec 12, 2022

This for #1335. The distribution does not exist exactly in jax.random, but it can be easily transformed. I chose loc + jax.random.truncated_normal(rng_key, 0.0, jax.numpy.inf, size, dtype) * scale, but you could equally choose loc + jax.numpy.abs(jax.random.normal(rng_key, size, dtype)) * scale. It's up to the maintainers to decide which they prefer. The tests work but I'm happy to implement more tests if desired.

Might also need to consider key splitting order #1345.

Side note: I'm more familiar with the numpyro/tfp-style HalfNormal distributions, which have a single scale parameter are centred at zero.


Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/BR.
  • There are tests covering the changes introduced in the PR.

@rlouf rlouf linked an issue Dec 12, 2022 that may be closed by this pull request
@rlouf rlouf changed the title Add JAX implementation of Half Normal Add HalfNormalRV JAX implementation Dec 12, 2022
@brandonwillard brandonwillard added enhancement New feature or request JAX Involves JAX transpilation random variables Involves random variables and/or sampling labels Dec 12, 2022
@rlouf
Copy link
Member

rlouf commented Dec 13, 2022

This looks good! Waiting for #1345 to be merged. In the meantime, would you mind rebasing your changes and renaming the commit to match the PR title?

theorashid added a commit to theorashid/aesara that referenced this pull request Dec 13, 2022
@theorashid
Copy link
Contributor Author

I rebased and changed the split key to match #1345.

rlouf pushed a commit to theorashid/aesara that referenced this pull request Dec 13, 2022
@rlouf
Copy link
Member

rlouf commented Dec 13, 2022

I cleaned the commits; should be good to merge if the tests pass. Thank you for contributing!

@codecov
Copy link

codecov bot commented Dec 13, 2022

Codecov Report

Merging #1362 (e0b8fb1) into main (94f3f32) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1362   +/-   ##
=======================================
  Coverage   74.35%   74.36%           
=======================================
  Files         177      177           
  Lines       49056    49066   +10     
  Branches    10379    10379           
=======================================
+ Hits        36478    36488   +10     
  Misses      10285    10285           
  Partials     2293     2293           
Impacted Files Coverage Δ
aesara/link/jax/dispatch/random.py 100.00% <100.00%> (ø)

@rlouf rlouf merged commit bfcfe4b into aesara-devs:main Dec 14, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request JAX Involves JAX transpilation random variables Involves random variables and/or sampling
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add JAX implementation for HalfNormalRV
3 participants