Shape conventions for all DensityEstimators #1066
Conversation
Force-pushed from 71584a2 to b12f66a.
Force-pushed from 22921d4 to bba7fc9.
Looks already quite good (looked through the first files). Had a few questions/comments on handling the `_x_shape`.
Also, all tests currently fail due to `No module named 'sbi.neural_nets.categorial'`, likely because it moved to the density estimators module and requires rebasing on current main.
Force-pushed from acb51c2 to b33e1d8.
Thanks a lot for the review @manuelgloeckler! All comments are addressed.
heroic effort 👏 your brain must really be in shape...
thanks a lot!
I added a couple of questions here and there, mostly to resolve my own confusion.
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
# Shape of `x` is (iid_dim, *event_shape).
x = reshape_to_iid_batch_event(x, event_shape=x.shape[1:], leading_is_iid=True)
now we have two different notions of "iid" here, no? The leading PyTorch iid and the "Bayesian" iid in `x`? Maybe this does not matter, but just checking...
yes, indeed. We have two notions of iid here. Should we rename `iid_dim` to `sample_dim` throughout all changes?
yes, I think in the context of the SBI package it would be good to call the leading iid dimension `sample_dim` (`sample_shape` in torch?)
let's change it to `sample_dim` and aim to support shapes in the future.
I ran all the slow tests on this branch. Currently, the following tests fail:
FAILED tests/mnle_test.py::test_mnle_on_device[cpu] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_on_device[gpu] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_api[rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_api[vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-mcmc] - AssertionError: Batch shape of condition 20 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-mcmc] - AssertionError: Batch shape of condition 20 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_with_experimental_conditions - AssertionError: Batch shape of condition 20 and input 3 do not match.
All seem to be related to the log_prob computation.
Thanks, this looks good!
As discussed, please rename `iid_dim` to `sample_dim` throughout to avoid confusion with the i.i.d. notion in `x`. This also requires renaming the function to `reshape_to_sample_batch_event`, I guess.
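For readers following along, here is a minimal, hypothetical sketch of the reshaping being discussed. The helper below is an illustration written for this summary, not the function from the PR, and the keyword `leading_is_sample` is invented for the sketch; it only shows how iid trials of an observation end up in the leading sample dimension while the batch dimension becomes a singleton.

```python
import torch


def reshape_to_sample_batch_event(x, event_shape, leading_is_sample=True):
    """Illustrative only: bring `x` into shape (sample_dim, batch_dim, *event_shape)."""
    if leading_is_sample:
        # The leading dimension enumerates samples (e.g. iid trials of x_o);
        # insert a singleton batch dimension after it.
        return x.reshape(-1, *event_shape).unsqueeze(1)
    # Otherwise the leading dimension is the batch; prepend a singleton sample_dim.
    return x.reshape(-1, *event_shape).unsqueeze(0)


# Ten iid trials of a 3-dimensional observation ...
x_o = torch.randn(10, 3)
# ... become shape (sample_dim=10, batch_dim=1, event_shape=(3,)).
assert reshape_to_sample_batch_event(x_o, event_shape=(3,)).shape == (10, 1, 3)
```

Here the leading dimension plays the "sample" role of the estimator's convention even though, in the Bayesian sense, these are iid trials of a single observation, which is exactly the naming clash discussed above.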
x=self.x_o,
context=theta.to(self.device),
)  # type: ignore
# TODO log_prob_iid
Please link these lines in the MNLE IID issue, so that I don't forget it 🙏
Force-pushed from 89f401c to f9a90d0.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1066      +/-   ##
==========================================
- Coverage   85.26%   76.93%    -8.33%
==========================================
  Files          89       90       +1
  Lines        6616     6651      +35
==========================================
- Hits         5641     5117     -524
- Misses        975     1534     +559
Flags with carried forward coverage won't be shown.
Force-pushed from 8b00209 to a3dce4e.
Shape conventions for the DensityEstimator
All density estimators now have a unified format of their input and output shapes.
- `input` must have shape `(sample_dim: int, batch_dim: int, *event_shape: Tuple[int])`
- `condition` must have shape `(batch_dim: int, *event_shape: Tuple[int])`
- `batch_dim` of `input` and `condition` must be the same for `log_prob`.

API
- `density_estimator.log_prob(input, condition)`
- `density_estimator.sample(sample_shape, condition)`

Examples
- `density_estimator.log_prob(input, condition)`
- `density_estimator.sample(shape, condition)`
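To make the convention concrete, here is a small, self-contained sketch. `ToyDensityEstimator` is a stand-in written only for this example (it ignores the condition), and the return shapes, `(sample_dim, batch_dim)` for `log_prob` and `(*sample_shape, batch_dim, *event_shape)` for `sample`, reflect one reading of the convention above rather than text from this PR.

```python
import torch
from torch.distributions import Normal


class ToyDensityEstimator:
    """Stand-in estimator used only to illustrate the shape convention."""

    def log_prob(self, input, condition):
        # input: (sample_dim, batch_dim, *event_shape), condition: (batch_dim, *event_shape).
        assert input.shape[1] == condition.shape[0], "batch_dim of input and condition must match."
        # Standard-normal log-prob, summed over the event dimensions -> (sample_dim, batch_dim).
        return Normal(0.0, 1.0).log_prob(input).sum(dim=-1)

    def sample(self, sample_shape, condition):
        batch_dim, event_shape = condition.shape[0], (3,)  # event_shape fixed for the toy
        return torch.randn(*sample_shape, batch_dim, *event_shape)


density_estimator = ToyDensityEstimator()
inputs = torch.randn(20, 50, 3)   # (sample_dim=20, batch_dim=50, event_shape=(3,))
condition = torch.randn(50, 4)    # (batch_dim=50, event_shape=(4,))

assert density_estimator.log_prob(inputs, condition).shape == (20, 50)
assert density_estimator.sample((20,), condition).shape == (20, 50, 3)
```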
Issues to be opened
- `MixedDensityEstimator` cannot have `embedding_net`.
- `build_categoricalmassestimator` should have a z-score option and perform z-scoring by itself.
- `DensityEstimator` should be carrying `condition_shape`. `posterior` and `inference` should not carry these attributes.
- `log_prob_iid` of MNLE not working.

Limitations
Not all `DensityEstimator`s can...
- `batch_shapes` (but only `batch_dims`, i.e. scalar values)
- `log_prob` without adapting the `batch_dim` of the datapoint
- `log_prob` on data without `iid_dim_input`, i.e. `log_prob((50,3), (50,4))` might fail, but `log_prob((1,50,3), (50,4))` will work (see the sketch below).
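Continuing with the hypothetical `ToyDensityEstimator` from the Examples sketch above, the last limitation might look roughly like this in practice (assuming an estimator that strictly requires the three-dimensional input layout):

```python
import torch

x = torch.randn(50, 3)       # 50 data points, event_shape=(3,)
theta = torch.randn(50, 4)   # 50 matching conditions, event_shape=(4,)

# May fail for such estimators: `x` lacks the leading sample_dim.
# density_estimator.log_prob(x, theta)

# Works: add a singleton sample_dim so that `x` has shape (1, 50, 3).
log_probs = density_estimator.log_prob(x.unsqueeze(0), theta)  # -> shape (1, 50)
```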
Does this close any currently open issues?
Fixes #1041
Additional (but orthogonal) contributions
- Made `MixedDensityEstimator` independent of having to estimate the likelihood (it can, in principle, be used to estimate mixed posteriors).