Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GenGammaRV jax implementation #1450

Merged
merged 1 commit into from
Mar 10, 2023

Conversation

FredericWantiez
Copy link
Contributor

@FredericWantiez FredericWantiez commented Feb 21, 2023

Adding a jax implementation of GenGammaRV, closing #1333. Could do better with the test I've setup, but I don't think stats.cramervonmises can take a named scale argument.

Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/BR.
  • There are tests covering the changes introduced in the PR.

Don't worry, your PR doesn't need to be in perfect order to submit it. As development progresses and/or reviewers request changes, you can always rewrite the history of your feature/PR branches.

If your PR is an ongoing effort and you would like to involve us in the process, simply make it a draft PR.

@brandonwillard brandonwillard added enhancement New feature or request JAX Involves JAX transpilation labels Feb 22, 2023
@brandonwillard brandonwillard linked an issue Feb 22, 2023 that may be closed by this pull request
ref = stats.gengamma(alpha / p, p, scale=lam)
ref.random_state = rng_state
samples_ref = ref.rvs(n_samples)
test = stats.cramervonmises_2samp(samples, samples_ref)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is failing in CI because the Python 3.7 tests are using SciPy 1.61 and not > 1.7.0.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably change the SciPy lower bound to 1.7.0, because our NumPy lower bound is 1.17, and SciPy recommends a version greater than 1.7.0 for that. The only relevant SciPy constraint of which I'm aware is in numba-scipy, which currently upper-bounds SciPy at <= 1.7.3, but allows the desired 1.7.0 lower bound.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried and the tests pass locally with SciPy==1.7.2.

Copy link
Member

@rlouf rlouf Feb 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is to add the test to test_random_RandomVariable which uses stats.cramervonmises.

Copy link
Contributor Author

@FredericWantiez FredericWantiez Feb 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't use stats.cramervonmises in test_random_RandomVariable as it doesn't really allow you to handle named arguments and scipy also uses a different parametrization. I'm happy to change if bumping scipy just for this is not an option

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use a different parametrization in the tests, look at the test for the gamma distribution for instance. Unless I misunderstood your problem?

Copy link
Contributor Author

@FredericWantiez FredericWantiez Feb 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, somehow missed that. I've pushed something to use test_random_RandomVariable but feels brittle to force the loc and scale arguments this way. Some details on the issue with cramervonmisses here: scipy/scipy#18055

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I've rebased, squashed, and fixed the formatting issue from the last CI run.

@FredericWantiez
Copy link
Contributor Author

Sorry for the spam, rebased the wrong branch. Let me quickly fix it

@FredericWantiez FredericWantiez force-pushed the fred/gengamma branch 4 times, most recently from a6cafef to 7f343b3 Compare March 8, 2023 22:40
@FredericWantiez FredericWantiez marked this pull request as ready for review March 8, 2023 22:50
Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rebased, adjusted the docstring, and increased the number of samples drawn for the Cramér–von Mises test. This should be good to merge after the tests pass.

@brandonwillard brandonwillard enabled auto-merge (rebase) March 9, 2023 23:21
@brandonwillard brandonwillard merged commit b1fcfc9 into aesara-devs:main Mar 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request JAX Involves JAX transpilation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add JAX implementation for GenGammaRV
3 participants