Fix bug in univariate Ordered and SumTo1 transform logp #6903
Conversation
@@ -345,56 +345,13 @@ def check_transform_elementwise_logp(self, model):
            .sum()
            .eval({x_val_untransf: test_array_untransf})
        )
        close_to(v1, v2, tol)
        log_jac_det_eval = jacob_det.sum().eval({x_val_transf: test_array_transf})
Testing the jacobian provides a regression check
I think that you might be missing some edge cases if you sum the `jacob_det` and the other `logp`s. You might want to be pedantic with the test for the batch shapes and core shapes.
The batch shapes would raise in the reference case with `jacobian=True` and on the call to `model.logp(sum=False)` above. Indeed, that's what started to fail in #6897, where a simplex transform was applied to univariate uniforms (and `jacobian.shape=(2,)` did not align with `base_logp.shape=(3,)`). The shape issue would always crop up via static shape inference or when the function was evaluated (which it was not before).

The broadcasting issue (e.g., if you apply a simplex transform to a uniform with `shape=(2,)`) is now captured as well, by summing the jacobian and `base_logp` separately and only then adding them. I think the purpose of this test helper is fine as it is.
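A minimal numpy sketch (hypothetical values, not the actual PyMC test code) of why summing the two terms separately catches the broadcasting bug described above:

```python
import numpy as np

# Simplex transform on a uniform with shape=(2,):
base_logp = np.array([0.5, 0.5])  # shape (2,): one entry per uniform
jacobian = np.array([-1.0])       # shape (1,): one entry for the simplex

# Buggy: elementwise addition silently broadcasts the jacobian across
# the base logp's dimension, counting it twice before the reduction.
buggy = (base_logp + jacobian).sum()

# Fixed: reduce each term on its own, then add; the jacobian is
# counted exactly once, regardless of the base logp's shape.
fixed = base_logp.sum() + jacobian.sum()

print(buggy, fixed)  # -1.0 0.0
```

With mismatched non-broadcastable shapes, such as `(2,)` against `(3,)`, the naive addition raises outright, which is how the static shape checks surfaced the bug.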
Codecov Report
@@            Coverage Diff             @@
##             main    #6903      +/-   ##
==========================================
- Coverage   92.16%   90.71%    -1.45%
==========================================
  Files         100      100
  Lines       16878    16884       +6
==========================================
- Hits        15555    15317     -238
- Misses       1323     1567     +244
Looks good @ricardoV94. I left one important comment that needs to be addressed before merging this. The rest are minor nitpicks.
pymc/logprob/transforms.py (Outdated)
@@ -1229,6 +1232,10 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
    assert isinstance(value.owner.op, TransformedVariable)
    original_forward_value = value.owner.inputs[1]
    jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
    # Check if jacobian dimensions align with logp
    if jacobian.ndim < logp.ndim:
        diff_ndims = logp.ndim - jacobian.ndim
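A rough numpy sketch (shape handling only; the helper name is hypothetical and this is not the actual PyTensor graph code) of what the dimension check above amounts to:

```python
import numpy as np

def combine_logp_and_jacobian(logp, jacobian):
    """Collapse the trailing dimensions of logp that the (multivariate)
    transform's jacobian treats as core, then add the jacobian term."""
    if jacobian.ndim < logp.ndim:
        diff_ndims = logp.ndim - jacobian.ndim
        # Sum over the core dimensions so both terms share batch shape.
        logp = logp.sum(axis=tuple(range(-diff_ndims, 0)))
    return logp + jacobian

# Ordered transform on 3 univariate normals: the base logp has
# shape (3,), but the transform is multivariate, so its log-det
# jacobian is a scalar; the base logp is collapsed before adding.
logp = np.array([-1.0, -1.0, -1.0])
jac = np.array(0.0)
print(combine_logp_and_jacobian(logp, jac))  # -3.0
```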
I would love if you only had to test whether the `jacobian` and `logp` ndims were the same or slightly different. As far as I understand this is the key piece of the PR, so I want to be extra sure that we're covering every angle. Could you tell me if it makes sense to take a defensive approach and test the following circumstances?

1. `jacobian.ndim > logp.ndim`: an example would be a `Log` transform on an `MvNormal`. It doesn't really make sense to me, but it could happen under some circumstances that I don't really know at the moment. What would be the approach to take there? To sum across the jacobian's extra dimensions?
2. `jacobian.ndim <= logp.ndim` and `jacobian.shape != logp.shape[:jacobian.ndim]`: this would mean that the "batch" dimensions of the jacobian don't match the "batch" dimensions of the `logp`. To me this sounds like an error and it should raise at some point. The question is whether it makes sense to add an assertion to the `logp` graph or not. It might be a very picky thing to do, and completely wasteful.
3. `jacobian.ndim > logp.ndim` and `jacobian.shape[:logp.ndim] != logp.shape`: to me, this is just like the case listed above. Something was wrong in the `jacobian` or the `logp` and an error should be raised. The question is whether it makes sense to do so at each `logp` evaluation.

If you can argue why you don't need to test the above cases, then I think that this change is good enough.
Good points.

> `jacobian.ndim > logp.ndim`: an example would be a `Log` transform on an `MvNormal`. It doesn't really make sense to me, but it could happen under some circumstances that I don't really know at the moment. What would be the approach to take there? To sum across the jacobian's extra dimensions?

That touches on #6360 (comment), and in general the jacobian can't simply be summed if the core entries are somewhat "redundant". For instance, we shouldn't count the jacobian 3 times for a scale-transformed dirichlet([1, 1, 1]), but we can do it for a multivariate normal. I will add an error for this.

> `jacobian.ndim <= logp.ndim` and `jacobian.shape != logp.shape[:jacobian.ndim]`: This would mean that the "batch" dimensions of the jacobian don't match with the "batch" dimensions of the `logp`. To me this sounds like an error and it should raise at some point. The question is if it makes sense to add an assertion to the `logp` graph or not. It might be a very picky thing to do, and completely wasteful.

The nice thing is that the static shape information is capturing more and more of these errors without any extra symbolic operations. Indeed, this was what caused test failures when we bumped the PyTensor dependency in another PR, and it led me to realize our approach was completely wrong (even in the case where it doesn't fail because it broadcasts). A correct jacobian determinant should never affect the batch dimensions, and shouldn't have "lingering" core dimensions. I think it's fine to assume they will be correctly implemented going forward, and to leave edge bugs to be captured by static shape info and by making sure there is no sneaky broadcasting going on (which the changed test now does).
Ah we can check the broadcasting flags of the jacobian (and summed logp) match. This will also clearly raise for invalid broadcasting (which should never happen between jacobian and logp). Since we don't allow "dynamic broadcasting" anymore, even if something ends up with a dim length of 1 at runtime, and we didn't know about it, it will still fail.
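A plain-Python sketch of the kind of check being proposed here; the tuples stand in for PyTensor's per-dimension broadcastable flags, and the helper name is hypothetical:

```python
def assert_matching_broadcastable(jacobian_bcast, logp_bcast):
    """Hypothetical helper: raise if the jacobian's broadcastable pattern
    differs from the (summed) logp's, which would permit the jacobian to
    silently broadcast and be counted repeatedly."""
    if jacobian_bcast != logp_bcast:
        raise ValueError(
            f"jacobian pattern {jacobian_bcast} != logp pattern {logp_bcast}"
        )

# Matching batch dims: fine.
assert_matching_broadcastable((False,), (False,))

# A length-1 jacobian dim against a length-n logp dim would broadcast,
# so the check rejects it at graph-build time.
try:
    assert_matching_broadcastable((True,), (False,))
except ValueError as exc:
    print(exc)
```

Because PyTensor no longer broadcasts dimensions whose length is only discovered to be 1 at runtime, a mismatch that slips past this static check still fails when the function is evaluated.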
In general, I really don't like these multivariate transforms being applied to univariate RVs. I understand they were a very convenient way of implementing fancy constraints in V3, but they pose a couple of issues.

I think in the long term we would be better off restricting ourselves to transforms that match the "ndim_supp" of the base RV and don't change the meaning of the variable (a simplex-transformed dirichlet is still the same dirichlet, but an ordered-transformed one is in general not). The ZeroSumNormal is a good example of a more consistent approach, and I think we can do similar stuff for these transforms with distribution factories (just like Censored and Truncated do).
Force-pushed from a880d25 to b83f69a.
Details:
- Fix broadcasting bug in univariate Ordered and SumTo1 transform logp, and add an explicit check when building the graph
- Raise if a univariate transform is applied to a multivariate distribution
- Checks and logp reduction are applied even when the jacobian is not used
Force-pushed from b83f69a to 81d18af.
Looks good to me. Thanks @ricardoV94!
The logic introduced in #6255 was wrong because the jacobian of the multivariate transform would broadcast with the univariate logp (counting it repeatedly!).
When one applies a transform that combines information across batch dimensions, those can no longer be considered independent. An ordered uniform is in fact a multivariate RV. Accordingly, after this PR the logp of the base RVs is collapsed before adding the respective jacobian term.
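A quick numpy illustration (not PyMC code) of the point that an ordered uniform is genuinely multivariate: sorting couples the components, so their logps cannot be treated as independent entries.

```python
import numpy as np

rng = np.random.default_rng(0)

# Draws from an "ordered uniform": sort iid U(0, 1) pairs.
draws = np.sort(rng.uniform(size=(10_000, 2)), axis=-1)

# The components are no longer independent: the first is always below
# the second, and the pair is positively correlated.
assert (draws[:, 0] <= draws[:, 1]).all()
print(np.corrcoef(draws[:, 0], draws[:, 1])[0, 1])  # roughly 0.5, not 0
```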
With the fix it's no longer needed to distinguish between univariate and multivariate cases.

CC @TimOliverMaier
Note: It would be nice to use `SymbolicRandomVariable`s for these kinds of transforms, so that PyMC logprob inference has accurate information about the ndim support of the variables. This would also allow us to return "correct" forward draws, although I am not sure what those would correspond to (rejection sampling for Ordered, and something like what the ZeroSumNormal does for the SumTo1?). This is related to #6360 and #5674.

📚 Documentation preview 📚: https://pymc--6903.org.readthedocs.build/en/6903/