Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Splitting before using JAX key #1345

Merged
merged 2 commits into from
Dec 13, 2022
Merged

Conversation

AdrienCorenflos
Copy link
Contributor

@AdrienCorenflos AdrienCorenflos commented Dec 9, 2022

This closes #1344
As discussed offline with @rlouf at the moment the jax splitting logic relies on knowing the internals on the splitting and doesn't follow the JAX best practices. This PR fixes this.

Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/BR.
  • There are tests covering the changes introduced in the PR.

@brandonwillard brandonwillard added JAX Involves JAX transpilation random variables Involves random variables and/or sampling labels Dec 9, 2022
@AdrienCorenflos
Copy link
Contributor Author

I ran pre-commit, should be good

@AdrienCorenflos
Copy link
Contributor Author

@rlouf

@rlouf
Copy link
Member

rlouf commented Dec 13, 2022

Should be good to merge once the tests pass. Thanks for spotting this and fixing it!

@codecov
Copy link

codecov bot commented Dec 13, 2022

Codecov Report

Merging #1345 (16b50ba) into main (2434cb4) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1345   +/-   ##
=======================================
  Coverage   74.35%   74.35%           
=======================================
  Files         177      177           
  Lines       49046    49056   +10     
  Branches    10379    10379           
=======================================
+ Hits        36468    36478   +10     
  Misses      10285    10285           
  Partials     2293     2293           
Impacted Files Coverage Δ
aesara/link/jax/dispatch/random.py 100.00% <100.00%> (ø)

@rlouf rlouf merged commit 94f3f32 into aesara-devs:main Dec 13, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JAX Involves JAX transpilation random variables Involves random variables and/or sampling
Projects
None yet
Development

Successfully merging this pull request may close these issues.

In JAX random linking splitting should happen before using the key
3 participants