-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batched MvNormal distribution #5424
Conversation
class BatchedMatrixInverse(Op): | ||
"""Computes the inverse of a matrix. | ||
|
||
`aesara.tensor.nlinalg.matrix_inverse` can only inverse square matrices. | ||
This Op can inverse batches of square matrices. | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two questions:
- Is there an existing function in aesara that can inverse batches of square matrices in a vectorized way?
- If no, I have added this
Op
in a new file. Where should this class ideally be placed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distributions/dist_math.py
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If numpy has support for batched matrix inversion but not Aesara, we can open an issue there. Here is a similar case: aesara-devs/aesara#791
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the Aesara Op is just doing the same as this one but has an unnecessary ndim check?
We should just open a PR to fix it there
Codecov Report
@@ Coverage Diff @@
## main #5424 +/- ##
==========================================
- Coverage 81.38% 74.74% -6.65%
==========================================
Files 82 83 +1
Lines 14207 14241 +34
==========================================
- Hits 11563 10645 -918
- Misses 2644 3596 +952
|
Thanks for the ping Ricardo. |
Thank your for opening a PR!
Initializing the work on #5383
Blocked by aesara-devs/aesara#798EDIT: Now, I am correctly using
at.broadcast_shape
function.Depending on what your PR does, here are a few things you might want to address in the description:
This PR generalizes the MvNormal distribution.
The parameters
mu
andcov
were restricted to 2d. To remove this restriction, I created a customOp
that operates on batches of square matrices to compute their inverses in vectorized fashion.With the changes in this PR, I think that random variable object for MvNormal distribution can be created and random samples can be drawn. The
logp
computations still need to be sorted out. So, this PR is WIP.Will add them soon.