Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gufunc signature to SymbolicRandomVariables #7159

Merged
merged 1 commit into from
Feb 29, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 16, 2024

Description

This meta-info is necessary to reason about batch dims of SymbolicRandomVariables in the context of pymc-devs/pymc-extras#300

This is probably what we should use for RandomVariables, instead of defining ndims_params and ndim_supp. Those can be properties derived from the gufunc signature.

There is however a limitation with gufunc signatures, which has to do with inputs and outputs that are not tensors, such as RNGs and the size vector which obviously cannot have batch dimensions. For now I am treating those as scalars so in the signature they show up as (), but perhaps it makes sense to deviate a bit from numpy and use [] or None?

For vanilla RandomVariables like Normal. the signature would be None,None,None,(),()->None,(), for the inputs: rng, size, dtype, mu, sigma and outputs: next_rng, draws.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7159.org.readthedocs.build/en/7159/

@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch from 217c3c4 to 0595bcf Compare February 16, 2024 16:01
Copy link

codecov bot commented Feb 16, 2024

Codecov Report

Attention: Patch coverage is 91.30435% with 6 lines in your changes are missing coverage. Please review.

Project coverage is 92.23%. Comparing base (a2988c7) to head (45eb6ef).
Report is 1 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7159      +/-   ##
==========================================
- Coverage   92.28%   92.23%   -0.06%     
==========================================
  Files         101      100       -1     
  Lines       16923    16889      -34     
==========================================
- Hits        15618    15578      -40     
- Misses       1305     1311       +6     
Files Coverage Δ
pymc/distributions/censored.py 100.00% <100.00%> (ø)
pymc/distributions/mixture.py 95.08% <100.00%> (+0.12%) ⬆️
pymc/distributions/multivariate.py 93.81% <100.00%> (+0.02%) ⬆️
pymc/distributions/timeseries.py 94.58% <100.00%> (+0.12%) ⬆️
pymc/distributions/distribution.py 94.29% <87.50%> (-0.76%) ⬇️

... and 2 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch from 0595bcf to 1e78baf Compare February 16, 2024 16:11
@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch from 1e78baf to b0a8a89 Compare February 16, 2024 17:23
@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch from b0a8a89 to 2a12e95 Compare February 18, 2024 19:42
@ricardoV94 ricardoV94 changed the title Add gufunc signature to pre-built CustomSymbolicDistributions Add gufunc signature to SymbolicRandomVariables Feb 18, 2024
@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch 5 times, most recently from 7f26000 to 1a3744c Compare February 18, 2024 20:50
@ricardoV94 ricardoV94 marked this pull request as ready for review February 18, 2024 20:54
@zaxtax
Copy link
Contributor

zaxtax commented Feb 21, 2024

Just for clarity, we don't have any sections of the code that use a tensor of RNG seeds right?

@ricardoV94
Copy link
Member Author

That's not a thing in PyTensor

Copy link
Contributor

@zaxtax zaxtax left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some questions, but understand it enough

self.ndims_params = [len(sig) for sig in inputs_sig]
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig)

if self.ndim_supp is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is ndim_supp going to be deprecated? If so should we include a warning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be deprecated but I didn't want to decide yet on that. Won't be deprecated without a warning beforehand

@@ -594,7 +629,7 @@ def rv_op(
random: Optional[Callable],
moment: Optional[Callable],
ndim_supp: int,
ndims_params: Optional[Sequence[int]],
ndims_params: Sequence[int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this no longer Optional while everywhere else ndim becomes optional

Copy link
Member Author

@ricardoV94 ricardoV94 Feb 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users don't call this method, the user facing one dist and __new__ always defines ndim from the optional arguments and then passes it here

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, added nitpicks but no blockers. Maybe the code example in the pm.CustomDist (and/or SymbolicRandomVariable for future contributors) would be the closest.

In general these are great though. Much easier to grok than ndims_support, and potentially more useful.

pymc/distributions/distribution.py Outdated Show resolved Hide resolved
pymc/distributions/distribution.py Show resolved Hide resolved
tests/distributions/test_distribution.py Show resolved Hide resolved
@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch 2 times, most recently from 14e2f05 to d81ec05 Compare February 28, 2024 11:15
@ricardoV94 ricardoV94 force-pushed the add_gufunc_signatures branch from d81ec05 to 45eb6ef Compare February 28, 2024 11:36
@ricardoV94 ricardoV94 merged commit 6252d2e into pymc-devs:main Feb 29, 2024
23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants