
Shape conventions for all DensityEstimators #1066

Merged · 1 commit merged on Apr 9, 2024

Conversation

michaeldeistler
Contributor

@michaeldeistler michaeldeistler commented Mar 20, 2024

Shape conventions for the DensityEstimator

All density estimators now have a unified format of their input and output shapes.

  • input must have shape (sample_dim: int, batch_dim: int, *event_shape: Tuple[int])
  • condition must have shape (batch_dim: int, *event_shape: Tuple[int])
  • batch_dim of input and condition must be the same for log_prob.

API

density_estimator.log_prob(input, condition)

input: (sample_input, batch_input, *event_shape_input)
condition: (batch_condition, *event_shape_condition)
returns: (sample_input, batch_input)
raises: an error if batch_input != batch_condition

density_estimator.sample(sample_shape, condition)

sample_shape: (*sample_shape)
condition: (batch_condition, *event_shape_condition)
returns: (*sample_shape, batch_condition, *event_shape_input)
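The API above can be illustrated with a minimal toy stand-in (a sketch only, not the actual sbi DensityEstimator implementation; the class name, the fixed event shape, and the dummy zero log-probs are assumptions for illustration):

```python
import torch
from torch import Tensor

class ToyDensityEstimator:
    """Toy stand-in that only checks/produces the documented shapes."""

    def __init__(self, event_shape=(3,)):
        self.event_shape = event_shape

    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
        sample_dim, batch_dim = input.shape[0], input.shape[1]
        if batch_dim != condition.shape[0]:
            raise ValueError(
                f"Batch dims must match: {batch_dim} != {condition.shape[0]}."
            )
        # Dummy values; a real estimator would evaluate the density here.
        return torch.zeros(sample_dim, batch_dim)

    def sample(self, sample_shape, condition: Tensor) -> Tensor:
        batch_dim = condition.shape[0]
        # Returns (*sample_shape, batch_dim, *event_shape).
        return torch.zeros(*sample_shape, batch_dim, *self.event_shape)

estimator = ToyDensityEstimator(event_shape=(3,))
assert estimator.log_prob(torch.zeros(10, 5, 3), torch.zeros(5, 6)).shape == (10, 5)
assert estimator.sample((10,), torch.zeros(5, 6)).shape == (10, 5, 3)
```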

Examples

density_estimator.log_prob(input, condition)

(1, 1, 3), (1, 6) -> (1, 1)
(1, 5, 3), (5, 6) -> (1, 5)
(10, 5, 3), (5, 6) -> (10, 5)
(1, 1, 3), (5, 6) -> Error, batch dims must match

density_estimator.sample(shape, condition)

(1,), (1, 6) -> (1, *event_shape_input)
(1,), (5, 6) -> (1, 5, *event_shape_input)
(10,), (5, 6) -> (10, 5, *event_shape_input)
(10, 3), (5, 6) -> (10, 3, 5, *event_shape_input)
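The examples above mirror the torch.distributions convention, where sample(sample_shape) returns sample_shape + batch_shape + event_shape. A quick check with an analytic distribution (illustrative only; the sbi estimators wrap learned neural density estimators, not analytic distributions):

```python
import torch
from torch.distributions import MultivariateNormal

batch_condition, event_dim = 5, 3  # assumed sizes for illustration
# loc of shape (5, 3) gives batch_shape (5,) and event_shape (3,).
dist = MultivariateNormal(
    torch.zeros(batch_condition, event_dim), torch.eye(event_dim)
)

samples = dist.sample((10, 3))
assert samples.shape == (10, 3, 5, 3)  # sample_shape + batch + event

log_probs = dist.log_prob(torch.zeros(10, 5, 3))
assert log_probs.shape == (10, 5)  # sample_dim x batch_dim
```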

Issues to be opened

  • MixedDensityEstimator cannot have embedding_net.
  • build_categoricalmassestimator should have z-score option and perform z-scoring by itself.
  • Only DensityEstimator should be carrying condition_shape. posterior and inference should not carry these attributes.
  • log_prob_iid of MNLE not working

Limitations

Not all DensityEstimators can...

  • ...handle batch_shapes (only a single batch_dim, i.e., a scalar batch size)
  • ...evaluate one datapoint under multiple conditions with log_prob without adapting the batch_dim of the datapoint
  • ...perform log_prob on data without a leading iid/sample dim, i.e., log_prob((50, 3), (50, 4)) might fail, but log_prob((1, 50, 3), (50, 4)) will work
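For the last limitation, a possible workaround under the convention above (a sketch; `density_estimator` is assumed to exist and to follow the documented shapes) is to add the missing leading sample dimension explicitly:

```python
import torch

x = torch.randn(50, 3)          # data without a leading sample dim
condition = torch.randn(50, 4)

# Adding a sample dim of size 1 yields the documented input shape
# (sample_dim, batch_dim, *event_shape) = (1, 50, 3).
x = x.unsqueeze(0)
assert x.shape == (1, 50, 3)
# density_estimator.log_prob(x, condition) would then return shape (1, 50).
```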

Does this close any currently open issues?

Fixes #1041

Additional (but orthogonal) contributions

  • Make the naming of variables in MixedDensityEstimator independent of having to estimate the likelihood (it can, in principle, be used to estimate mixed posteriors)

Checklist

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)

@michaeldeistler michaeldeistler changed the title [WIP] NflowsFlow follows new shape convenctions [WIP] Shape conventions for all DensityEstimators Mar 20, 2024
@michaeldeistler michaeldeistler force-pushed the useshapes branch 4 times, most recently from 71584a2 to b12f66a on March 21, 2024 09:22
@janfb janfb mentioned this pull request Mar 22, 2024
@michaeldeistler michaeldeistler force-pushed the useshapes branch 6 times, most recently from 22921d4 to bba7fc9 on March 25, 2024 16:33
@michaeldeistler michaeldeistler changed the title [WIP] Shape conventions for all DensityEstimators Shape conventions for all DensityEstimators Apr 2, 2024
@michaeldeistler michaeldeistler marked this pull request as ready for review April 2, 2024 07:18
Contributor

@manuelgloeckler manuelgloeckler left a comment

Looks already quite good (looked through the first files). Had a few questions/comments on handling the _x_shape.

Also, all tests currently fail due to No module named 'sbi.neural_nets.categorial'.
Likely because the module moved into density_estimators, or the branch needs a rebase on current main.

@michaeldeistler
Contributor Author

Thanks a lot for the review @manuelgloeckler! All comments are addressed.

Contributor

@janfb janfb left a comment


heroic effort 👏 your brain must really be in shape..
thanks a lot!
I added a couple of questions here and there, mostly to resolve my own confusion.

sbi/inference/posteriors/direct_posterior.py
sbi/inference/posteriors/direct_posterior.py (outdated)
sbi/inference/posteriors/direct_posterior.py
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
# Shape of `x` is (iid_dim, *event_shape).
x = reshape_to_iid_batch_event(x, event_shape=x.shape[1:], leading_is_iid=True)
Contributor

now we have two different notions of "iid" here, no?
The leading PyTorch iid and the "Bayesian" iid in x? Maybe this does not matter, but just checking...

Contributor Author

yes, indeed. We have two notions of iid here. Should we rename iid_dim to sample_dim throughout all changes?

Contributor

yes, I think in the context of the SBI package it would be good to call the leading iid dimension sample_dim (sample_shape in torch?)

Contributor

let's change it to sample_dim and aim to support shapes in the future.

tests/density_estimator_test.py
tests/density_estimator_test.py
@manuelgloeckler
Copy link
Contributor

manuelgloeckler commented Apr 4, 2024

I ran all the slow tests on this branch. Currently, the mnle_test.py tests are failing with

FAILED tests/mnle_test.py::test_mnle_on_device[cpu] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_on_device[gpu] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_api[rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_api[vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-mcmc] - AssertionError: Batch shape of condition 20 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[5-vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-mcmc] - AssertionError: Batch shape of condition 20 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-rejection] - AssertionError: Batch shape of condition 10000 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_accuracy_with_different_samplers_and_trials[10-vi] - AssertionError: Batch shape of condition 256 and input 2 do not match.
FAILED tests/mnle_test.py::test_mnle_with_experimental_conditions - AssertionError: Batch shape of condition 20 and input 3 do not match.

All seem to be related to the log_prob computation in CategoricalMassEstimator, more specifically in
sbi/neural_nets/density_estimators/categorical_net.py:137: AssertionError

Contributor

@janfb janfb left a comment


Thanks, this looks good!

As discussed, please rename iid_dim to sample_dim throughout to avoid confusion with the i.i.d notion in x. This also requires renaming the function to reshape_to_sample_batch_event I guess.
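A minimal sketch of what the renamed helper could look like (hypothetical; the actual signature and behavior of `reshape_to_sample_batch_event` in sbi may differ):

```python
import torch
from torch import Tensor

def reshape_to_sample_batch_event(
    x: Tensor, event_shape, leading_is_sample: bool = False
) -> Tensor:
    """Reshape `x` to (sample_dim, batch_dim, *event_shape).

    Hypothetical sketch: if the leading dim of `x` is the sample dim, a
    batch dim of size 1 is inserted; otherwise a sample dim of size 1.
    """
    x = x.reshape(-1, *event_shape)
    return x.unsqueeze(1) if leading_is_sample else x.unsqueeze(0)

x = torch.randn(50, 3)
assert reshape_to_sample_batch_event(x, (3,), leading_is_sample=True).shape == (50, 1, 3)
assert reshape_to_sample_batch_event(x, (3,)).shape == (1, 50, 3)
```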

sbi/inference/posteriors/direct_posterior.py (outdated)
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
# Shape of `x` is (iid_dim, *event_shape).
x = reshape_to_iid_batch_event(x, event_shape=x.shape[1:], leading_is_iid=True)
Contributor

let's change it to sample_dim and aim to support shapes in the future.

x=self.x_o,
context=theta.to(self.device),
) # type: ignore
# TODO log_prob_iid
Contributor

Please link these lines in the MNLE IID issue. so that I don't forget it 🙏

@michaeldeistler michaeldeistler force-pushed the useshapes branch 7 times, most recently from 89f401c to f9a90d0 on April 9, 2024 12:10

codecov bot commented Apr 9, 2024

Codecov Report

Attention: Patch coverage is 84.97110%, with 26 lines in your changes missing coverage. Please review.

Project coverage is 76.93%. Comparing base (46db263) to head (a3dce4e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1066      +/-   ##
==========================================
- Coverage   85.26%   76.93%   -8.33%     
==========================================
  Files          89       90       +1     
  Lines        6616     6651      +35     
==========================================
- Hits         5641     5117     -524     
- Misses        975     1534     +559     
Flag Coverage Δ
unittests 76.93% <84.97%> (-8.33%) ⬇️

Flags with carried forward coverage won't be shown.

Files Coverage Δ
sbi/inference/posteriors/direct_posterior.py 98.43% <100.00%> (+6.12%) ⬆️
...inference/potentials/likelihood_based_potential.py 100.00% <100.00%> (ø)
.../inference/potentials/posterior_based_potential.py 97.05% <100.00%> (+0.28%) ⬆️
sbi/inference/potentials/ratio_based_potential.py 100.00% <ø> (ø)
sbi/inference/snle/mnle.py 86.95% <100.00%> (-6.07%) ⬇️
sbi/inference/snle/snle_base.py 93.68% <100.00%> (+0.20%) ⬆️
sbi/inference/snpe/snpe_base.py 89.09% <100.00%> (+0.34%) ⬆️
.../neural_nets/density_estimators/categorical_net.py 98.03% <100.00%> (+9.15%) ⬆️
sbi/neural_nets/density_estimators/zuko_flow.py 64.44% <100.00%> (-14.51%) ⬇️
sbi/neural_nets/mnle.py 100.00% <100.00%> (ø)
... and 7 more

... and 21 files with indirect coverage changes


Successfully merging this pull request may close these issues.

A unified format for shapes passed to all Estimators
3 participants