
Fix bug in univariate Ordered and SumTo1 transform logp #6903

Merged · 2 commits · Sep 13, 2023

Conversation

ricardoV94 (Member) commented Sep 12, 2023

The logic introduced in #6255 was wrong because the jacobian of the multivariate transform would broadcast with the univariate logp (counting it repeatedly!).

When one applies a transform that combines information across batch dimensions, those can no longer be considered independent. An ordered uniform is in fact a multivariate RV. Accordingly, after this PR the logp of the base RVs is collapsed before adding the respective jacobian term.

With this fix there is no longer any need to distinguish between the univariate and multivariate cases.
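The double counting described above can be illustrated with a minimal NumPy sketch (shapes and values are illustrative, not taken from the PR): a transform that couples all entries contributes a single jacobian term, but elementwise addition broadcasts it against every base logp.

```python
import numpy as np

base_logp = np.array([-0.9, -1.1, -1.3])  # one logp per univariate base RV
log_jac_det = np.array(-0.5)              # one term for the whole transform

# Buggy: elementwise addition broadcasts the scalar jacobian against
# every base logp, so it is counted 3 times in the total.
buggy_total = (base_logp + log_jac_det).sum()

# Fixed: collapse the base logps first, then add the jacobian once.
fixed_total = base_logp.sum() + log_jac_det

# The buggy total overcounts the jacobian by (n - 1) = 2 extra copies.
assert np.isclose(buggy_total, fixed_total + 2 * log_jac_det)
```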

CC @TimOliverMaier


Note: It would be nice to use SymbolicRandomVariables for these kinds of transforms, so that PyMC logprob inference has accurate information about the ndim support of the variables. This would also allow us to return "correct" forward draws, although I am not sure what those would correspond to (rejection sampling for Ordered, and something like what the ZeroSumNormal does for the SumTo1?). This is related to #6360 and #5674.


📚 Documentation preview 📚: https://pymc--6903.org.readthedocs.build/en/6903/

@@ -345,56 +345,13 @@ def check_transform_elementwise_logp(self, model):
.sum()
.eval({x_val_untransf: test_array_untransf})
)
close_to(v1, v2, tol)
log_jac_det_eval = jacob_det.sum().eval({x_val_transf: test_array_transf})
ricardoV94 (Member Author):
Testing the jacobian provides a regression check

Contributor:

I think that you might be missing some edge cases if you sum the jacob_det and the other logps. You might want to be pedantic with the test for the batch shapes and core shapes.

ricardoV94 (Member Author) Sep 13, 2023:

The batch shapes would raise in the reference case with jacobian=True and on the call to model.logp(sum=False) above. Indeed, that's what started to fail in #6897, where a simplex transform was applied to univariate uniforms (and jacobian.shape=(2,) did not align with base_logp.shape=(3,)). The shape issue would always be caught by static shape inference or when the function was evaluated (which it was not before).

The broadcasting issue (e.g., if you do a simplex on a uniform with shape=(2,)) is now captured as well by summing the jacobian and base_logp separately, and only then adding them. I think the purpose of this test helper is fine as it is.
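The shape misalignment mentioned above can be sketched in plain NumPy (shapes are illustrative, mirroring the jacobian.shape=(2,) vs base_logp.shape=(3,) case): elementwise addition fails loudly, while summing each piece before combining is always well-defined.

```python
import numpy as np

base_logp = np.zeros(3)    # logp of 3 univariate uniforms, shape (3,)
log_jac_det = np.zeros(2)  # jacobian of a simplex-style space, shape (2,)

try:
    bad = base_logp + log_jac_det  # (3,) vs (2,): not broadcastable
    raised = False
except ValueError:
    raised = True

# Summing each piece separately before combining sidesteps broadcasting
# entirely, even when the shapes legitimately differ:
total = base_logp.sum() + log_jac_det.sum()
```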

codecov bot commented Sep 12, 2023

Codecov Report

Merging #6903 (f2d644c) into main (d659848) will decrease coverage by 1.45%.
The diff coverage is 100.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #6903      +/-   ##
==========================================
- Coverage   92.16%   90.71%   -1.45%     
==========================================
  Files         100      100              
  Lines       16878    16884       +6     
==========================================
- Hits        15555    15317     -238     
- Misses       1323     1567     +244     
Files Changed Coverage Δ
pymc/distributions/discrete.py 99.11% <ø> (ø)
pymc/distributions/mixture.py 94.97% <ø> (ø)
pymc/distributions/transforms.py 100.00% <100.00%> (+0.62%) ⬆️
pymc/logprob/transforms.py 94.90% <100.00%> (+0.07%) ⬆️

... and 3 files with indirect coverage changes

lucianopaz (Contributor) left a comment:

Looks good @ricardoV94. I left one important comment that needs to be addressed before merging this; the rest are minor nitpicks.

pymc/distributions/transforms.py — 4 resolved review threads (3 outdated)
@@ -1229,6 +1232,10 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
assert isinstance(value.owner.op, TransformedVariable)
original_forward_value = value.owner.inputs[1]
jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
# Check if jacobian dimensions align with logp
if jacobian.ndim < logp.ndim:
diff_ndims = logp.ndim - jacobian.ndim
lucianopaz (Contributor) Sep 12, 2023:

I would love if you only had to test whether the jacobian and logp ndims were the same or slightly different. As far as I understand this is the key piece of the PR so I want to be extra sure that we're covering every angle. Could you tell me if it makes sense to take a defensive approach and test the following circumstances?

  • jacobian.ndim > logp.ndim: an example would be a Log transform on an MvNormal. It doesn't really make sense to me, but it could happen under some circumstances that I don't really know at the moment. What would be the approach to take there? To sum across the jacobian's extra dimensions?
  • jacobian.ndim <= logp.ndim and jacobian.shape != logp.shape[:jacobian.ndim]: This would mean that the "batch" dimensions of the jacobian don't match the "batch" dimensions of the logp. To me this sounds like an error and it should raise at some point. The question is whether it makes sense to add an assertion to the logp graph or not. It might be a very picky thing to do, and completely wasteful.
  • jacobian.ndim > logp.ndim and jacobian.shape[:logp.ndim] != logp.shape: to me, this is just like the case listed above. Something was wrong in the jacobian or the logp and an error should be raised. The question is whether it makes sense to do so at each logp evaluation.

If you can argue why you don't need to test the above cases, then I think this change is good enough.

ricardoV94 (Member Author) Sep 13, 2023:

Good points.

jacobian.ndim > logp.ndim: an example would be a Log transform on an MvNormal. It doesn't really make sense to me, but it could happen under some circumstances that I don't really know at the moment. What would be the approach to take there? To sum the across the jacobian extra dimensions?

That touches on #6360 (comment); in general the jacobian can't simply be summed if the core entries are somewhat "redundant". For instance, we shouldn't count the jacobian 3 times for a scale-transformed dirichlet([1, 1, 1]), but we can do it for a multivariate normal. I will add an error for this.

jacobian.ndim <= logp.ndim and jacobian.shape != logp.shape[:jacobian.ndim]: This would mean that the "batch" dimensions of the jacobian don't match with the "batch" dimensions of the logp. To me this sounds like an error and it should at some point. The question is if it makes sense to add an assertion to the logp graph or not. It might be a very picky thing to do, and completely wasteful.

The nice thing is that static shape information is capturing more and more of these errors without any extra symbolic operations. Indeed, this is what caused test failures when we bumped the PyTensor dependency in another PR, and led me to realize our approach was completely wrong (even in the cases where it doesn't fail because it broadcasts). A correct jacobian determinant should never affect the batch dimensions, and shouldn't have "lingering" core dimensions. I think it's fine to assume transforms will be correctly implemented going forward, and to leave edge bugs to be caught by static shape info while making sure there is no sneaky broadcasting going on (which the changed test now does).

ricardoV94 (Member Author):

Ah, we can check that the broadcastable flags of the jacobian (and summed logp) match. This will also clearly raise for invalid broadcasting (which should never happen between jacobian and logp). Since we don't allow "dynamic broadcasting" anymore, even if something ends up with a dim length of 1 at runtime that we didn't know about, it will still fail.
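The defensive check discussed in this thread can be sketched with a hypothetical helper (the name `assert_no_hidden_broadcast` is illustrative, not PyMC API): any dimension mismatch between the reduced logp and the jacobian, including the 1-vs-n case that NumPy or PyTensor would silently broadcast, is treated as an error.

```python
def assert_no_hidden_broadcast(logp_shape, jacobian_shape):
    """Raise if a (reduced) logp and a jacobian could only combine via
    broadcasting. Hypothetical helper mirroring the idea in this thread;
    not part of PyMC's actual implementation."""
    if len(logp_shape) != len(jacobian_shape):
        raise ValueError(
            f"ndim mismatch: logp {logp_shape} vs jacobian {jacobian_shape}"
        )
    for dim_logp, dim_jac in zip(logp_shape, jacobian_shape):
        # A strict equality check also rejects the 1-vs-n broadcasting case.
        if dim_logp != dim_jac:
            raise ValueError(
                f"shape mismatch: logp {logp_shape} vs jacobian {jacobian_shape}"
            )

assert_no_hidden_broadcast((5,), (5,))  # matching batch shapes: fine
```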

tests/distributions/test_transform.py — 3 resolved review threads (1 outdated)
pymc/logprob/transforms.py — 1 resolved review thread (outdated)

ricardoV94 (Member Author) commented Sep 13, 2023

In general, I really don't like these multivariate transforms being applied to univariate RVs. I understand they were a very convenient way of implementing fancy constraints in V3, but they pose a couple of issues:

  1. We no longer know the true support dimension of variables when doing logp inference. This isn't too bad, however, because the transforms come at the "end" of the graph, and derived variables can never "extract their measurability" from other valued variables.
  2. Forward draws can't obviously respect these constraints, which is surprising for many users.

I think in the long term we would be better off restricting ourselves to transforms that match the "ndim_supp" of the base RV and that don't change the meaning of the variable (a simplex-transformed dirichlet is still the same dirichlet, but an ordered-transformed one in general is not). The ZeroSumNormal is a good example of a more consistent approach, and I think we can do similar stuff for these transforms with distribution factories (just like Censored and Truncated do).

Details:
* Fix broadcasting bug in univariate Ordered and SumTo1 transform logp, and add an explicit check when building the graph
* Raise if univariate transform is applied to multivariate distribution
* Checks and logp reduction are applied even when jacobian is not used
@ricardoV94 ricardoV94 mentioned this pull request Sep 13, 2023
lucianopaz (Contributor) left a comment:

Looks good to me. Thanks @ricardoV94!

@ricardoV94 ricardoV94 merged commit eb7c3b6 into pymc-devs:main Sep 13, 2023
20 of 21 checks passed
@ricardoV94 ricardoV94 deleted the remove_ndim_supp_transforms branch September 13, 2023 14:06