-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add gufunc signature to SymbolicRandomVariables #7159
Conversation
217c3c4
to
0595bcf
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7159 +/- ##
==========================================
- Coverage 92.28% 92.23% -0.06%
==========================================
Files 101 100 -1
Lines 16923 16889 -34
==========================================
- Hits 15618 15578 -40
- Misses 1305 1311 +6
|
0595bcf
to
1e78baf
Compare
1e78baf
to
b0a8a89
Compare
b0a8a89
to
2a12e95
Compare
7f26000
to
1a3744c
Compare
Just for clarity, we don't have any sections of the code that use a tensor of RNG seeds right? |
That's not a thing in PyTensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some questions, but understand it enough
self.ndims_params = [len(sig) for sig in inputs_sig] | ||
self.ndim_supp = max(len(out_sig) for out_sig in outputs_sig) | ||
|
||
if self.ndim_supp is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is ndim_supp
going to be deprecated? If so should we include a warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be deprecated but I didn't want to decide yet on that. Won't be deprecated without a warning beforehand
@@ -594,7 +629,7 @@ def rv_op( | |||
random: Optional[Callable], | |||
moment: Optional[Callable], | |||
ndim_supp: int, | |||
ndims_params: Optional[Sequence[int]], | |||
ndims_params: Sequence[int], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this no longer Optional
while everywhere else ndim becomes optional
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Users don't call this method, the user facing one dist
and __new__
always defines ndim from the optional arguments and then passes it here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, added nitpicks but no blockers. Maybe the code example in the pm.CustomDist
(and/or SymbolicRandomVariable
for future contributors) would be the closest.
In general these are great though. Much easier to grok than ndims_support, and potentially more useful.
14e2f05
to
d81ec05
Compare
d81ec05
to
45eb6ef
Compare
Description
This meta-info is necessary to reason about batch dims of SymbolicRandomVariables in the context of pymc-devs/pymc-extras#300
This is probably what we should use for RandomVariables, instead of defining
ndims_params
andndim_supp
. Those can be properties derived from the gufunc signature.There is however a limitation with gufunc signatures, which has to do with inputs and outputs that are not tensors, such as RNGs and the size vector which obviously cannot have batch dimensions. For now I am treating those as scalars so in the signature they show up as
()
, but perhaps it makes sense to deviate a bit from numpy and use[]
orNone
?For vanilla RandomVariables like Normal. the signature would be
None,None,None,(),()->None,()
, for the inputs: rng, size, dtype, mu, sigma and outputs: next_rng, draws.Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7159.org.readthedocs.build/en/7159/