
Getting rid of need for get_log_abs_det_jacobian function #15

Closed
wants to merge 3 commits

Conversation

@janfb (Contributor) commented Jul 13, 2021

Long story short: this PR proposes a fix for the problem below by constructing all transforms such that the sum over the second dimension of the parameters Tensor is performed inside the ladj method itself.

At the moment we have a helper function get_log_abs_det_jacobian (ladj) to make sure the ladj is returned summed over the parameter dimension. I looked into this a bit and now understand why the sum is sometimes performed and sometimes not.

The "raw" transforms from torch.distributions.transforms assume event_dim=0 when initialised, so the transforms constructed here:

sbibm/sbibm/utils/pyro.py

Lines 99 to 102 in 89755d2

if automatic_transform_enabled:
    transforms[name] = biject_to(fn.support).inv
else:
    transforms[name] = dist.transforms.identity_transform

will all have event_dim=0.
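The element-wise behaviour of such an event_dim=0 transform can be reproduced in isolation; the Exponential prior below is just a stand-in for any prior with a non-Independent support:

```python
import torch
from torch.distributions import Exponential, biject_to

# biject_to on a plain (non-Independent) support yields a transform
# with event_dim=0, so the ladj is computed element-wise.
fn = Exponential(rate=torch.ones(3))
transform = biject_to(fn.support).inv  # constrained -> unconstrained

params_constrained = torch.rand(4, 3)  # (batch_size, dim_parameters)
params_unconstrained = transform(params_constrained)

ladj = transform.log_abs_det_jacobian(params_constrained, params_unconstrained)
print(transform.event_dim)  # 0
print(ladj.shape)           # torch.Size([4, 3]) -- not summed over dim 1
```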
Thus, given a parameter with shape (batch_size, dim_parameters), accumulating the ladj to get a transform-corrected log_prob here:

sbibm/sbibm/utils/pyro.py

Lines 211 to 214 in 89755d2

for name, t in self.transforms.items():
    log_joint -= get_log_abs_det_jacobian(
        t, params_constrained[name], params[name]
    )

would result in a ladj of shape (batch_size, dim_parameters) instead of (batch_size,), unless we route it through get_log_abs_det_jacobian as we currently do.
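I won't reproduce the helper here, but a plausible sketch of what it has to do (not the actual sbibm implementation) is:

```python
import torch
from torch.distributions import Exponential, biject_to


def get_log_abs_det_jacobian(transform, x, y):
    # Hypothetical sketch: compute the ladj and sum over the parameter
    # dimension when the transform treats each entry independently
    # (event_dim == 0), so the result always has shape (batch_size,).
    ladj = transform.log_abs_det_jacobian(x, y)
    if transform.event_dim == 0 and ladj.dim() > 1:
        ladj = ladj.sum(-1)
    return ladj


t = biject_to(Exponential(torch.ones(3)).support).inv
x = torch.rand(4, 3)
y = t(x)
print(get_log_abs_det_jacobian(t, x, y).shape)  # torch.Size([4])
```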

But I think there is a cleaner solution. We noticed before that in some cases the sum over the second dimension is indeed calculated. This happens when the transform is defined as an IndependentTransform with reinterpreted_batch_ndims=1, which is what calling biject_to on a BoxUniform prior produces automatically, because that prior already has reinterpreted batch dims.
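This can be checked with a BoxUniform-style prior built from torch.distributions directly (Independent over Uniform, which is essentially what BoxUniform is):

```python
import torch
from torch.distributions import Independent, Uniform, biject_to

# BoxUniform-style prior: Uniform over a box with the last dim
# reinterpreted as an event dim.
prior = Independent(Uniform(torch.zeros(3), torch.ones(3)), 1)

# Because the support has one reinterpreted batch dim, biject_to
# returns an IndependentTransform, which sums the ladj over the
# last dimension automatically.
transform = biject_to(prior.support).inv

theta = torch.rand(4, 3) * 0.8 + 0.1  # (batch_size, dim_parameters), inside the box
u = transform(theta)
ladj = transform.log_abs_det_jacobian(theta, u)
print(ladj.shape)  # torch.Size([4]) -- already summed over the parameter dim
```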

With this fix we would not need get_log_abs_det_jacobian anymore. However, we first have to check that it is bulletproof, e.g., what happens if the prior has more than two dimensions? Is that case even allowed in sbibm?
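Concretely, the construction could look something like this (a sketch of the idea, not necessarily the exact diff in this PR): always wrap the bijection in an IndependentTransform so that every transform sums its ladj over the parameter dimension.

```python
import torch
from torch.distributions import Exponential, biject_to
from torch.distributions.transforms import IndependentTransform

# Wrap the base bijection so the ladj is summed over the last dim,
# regardless of whether the prior's support is Independent.
base = biject_to(Exponential(torch.ones(3)).support)
transform = IndependentTransform(base, reinterpreted_batch_ndims=1).inv

x = torch.rand(4, 3)  # (batch_size, dim_parameters)
y = transform(x)
print(transform.log_abs_det_jacobian(x, y).shape)  # torch.Size([4])
```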

@jan-matthis changed the title from "fix log abs det jakobijn event dim" to "Getting rid of need for get_log_abs_det_jacobian function" Jul 13, 2021
@janfb (Contributor, Author) commented Nov 8, 2021

close in favour of #27

@janfb janfb closed this Nov 8, 2021
@janfb janfb deleted the fix-ladj branch February 16, 2023 16:05