Getting rid of need for `get_log_abs_det_jacobian` function #15
Long story short, this PR proposes a fix for the problem below by constructing all transforms such that the sum over the second dimension of the parameters `Tensor` is already performed inside the `log_abs_det_jacobian` (`ladj`) method.

At the moment we have a helper function `get_log_abs_det_jacobian` to make sure the `ladj` is returned summed over the parameter dimension. I looked into this a bit and understood why sometimes the sum is performed and sometimes not.
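For reference, here is a hypothetical sketch of what such a helper boils down to; the actual signature and implementation in `sbibm` may differ:

```python
import torch
from torch.distributions.transforms import Transform


def get_log_abs_det_jacobian_sketch(
    transform: Transform, x: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
    """Hypothetical stand-in for the sbibm helper: return the ladj reduced
    to one value per batch row, summing over the parameter dimension only
    if the transform did not reduce it already."""
    ladj = transform.log_abs_det_jacobian(x, y)
    if ladj.dim() == x.dim():  # elementwise ladj, e.g. from an event_dim=0 transform
        ladj = ladj.sum(-1)
    return ladj
```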
The "raw" transforms from `torch.distributions.transforms` assume `event_dim=0` when initialised, so the transforms constructed here (`sbibm/utils/pyro.py`, lines 99 to 102 at `89755d2`) will all have `event_dim=0`.
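A minimal illustration of that default, using `SigmoidTransform` as a stand-in for the transforms built in `pyro.py` (plain PyTorch, not sbibm code):

```python
import torch
from torch.distributions import transforms

t = transforms.SigmoidTransform()  # a "raw" transform; event_dim defaults to 0
x = torch.randn(10, 3)             # (batch_size, dim_parameters)

print(t.event_dim)                            # 0
print(t.log_abs_det_jacobian(x, t(x)).shape)  # torch.Size([10, 3]): elementwise, not summed
```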
Thus, given a parameter with shape `(batch_size, dim_parameters)`, accumulating the `ladj` to get a transform-corrected `log_prob` here (`sbibm/utils/pyro.py`, lines 211 to 214 at `89755d2`) would result in a `ladj` of shape `(batch_size, dim_parameters)` instead of `(batch_size,)` (unless we use `get_log_abs_det_jacobian` as we currently do).
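Concretely, the correction fails for an `event_dim=0` transform unless the `ladj` is reduced first (again with `SigmoidTransform` and a standard normal as stand-ins):

```python
import torch
from torch.distributions import Normal, transforms

batch_size, dim_parameters = 10, 3
t = transforms.SigmoidTransform()  # event_dim=0 stand-in for the sbibm transforms
x = torch.randn(batch_size, dim_parameters)
y = t(x)

log_prob_y = Normal(0.0, 1.0).log_prob(y).sum(-1)  # (batch_size,)
ladj = t.log_abs_det_jacobian(x, y)                # (batch_size, dim_parameters)

# log_prob_y - ladj  # RuntimeError: (10,) and (10, 3) are not broadcastable
corrected = log_prob_y - ladj.sum(-1)  # works once the ladj is summed to (batch_size,)
print(corrected.shape)                 # torch.Size([10])
```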
But I think there is a cleaner solution. We noticed before that in some cases the sum over the second dimension is indeed calculated. This is the case when the transform is defined as an `IndependentTransform` with `reinterpreted_batch_ndims=1`, and this happens automatically when calling `biject_to` on a `BoxUniform` prior, because this prior has reinterpreted batch dims already.
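A sketch of that behaviour; here I approximate `BoxUniform` with `Independent(Uniform(...), 1)`, which is what it amounts to as far as I can tell:

```python
import torch
from torch.distributions import Independent, Uniform, biject_to
from torch.distributions.transforms import IndependentTransform, SigmoidTransform

batch_size, dim_parameters = 10, 3
x = torch.randn(batch_size, dim_parameters)

# Wrapping a raw transform explicitly makes the ladj sum over the last dim:
t = IndependentTransform(SigmoidTransform(), reinterpreted_batch_ndims=1)
print(t.log_abs_det_jacobian(x, t(x)).shape)  # torch.Size([10])

# The same wrapping happens automatically via biject_to on a BoxUniform-like prior:
prior = Independent(Uniform(torch.zeros(dim_parameters), torch.ones(dim_parameters)), 1)
t2 = biject_to(prior.support)
print(type(t2).__name__)                        # IndependentTransform
print(t2.log_abs_det_jacobian(x, t2(x)).shape)  # torch.Size([10]): already summed
```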
With this fix we would not need `get_log_abs_det_jacobian` anymore. However, we first have to check that it is bulletproof, e.g., what happens if the `prior` has more than two dimensions? Is that case even allowed in `sbibm`?
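As a starting point for that check, a quick probe suggests `biject_to` handles extra event dims by summing over all of them, assuming the prior is an `Independent` with matching `reinterpreted_batch_ndims`:

```python
import torch
from torch.distributions import Independent, Uniform, biject_to

# A prior with a two-dimensional event shape, e.g. a (2, 3) parameter grid:
prior = Independent(Uniform(torch.zeros(2, 3), torch.ones(2, 3)), 2)

t = biject_to(prior.support)  # IndependentTransform with reinterpreted_batch_ndims=2
x = torch.randn(10, 2, 3)     # (batch_size, *event_shape)
print(t.log_abs_det_jacobian(x, t(x)).shape)  # torch.Size([10]): summed over both event dims
```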