
refactor mixed density estimation #1203

Merged: 6 commits from refactor-mnle into main on Aug 2, 2024
Conversation

@janfb janfb (Contributor) commented Jul 24, 2024

What does this implement/fix? Explain your changes

A couple of fixes and improvements around MNLE:

  • log_prob for iid data is now much faster on the discrete data because we can simply pass the iid data as the sample dim to CategoricalNet (see the sketch after this list). There is no need for tricks with repetitions in the categorical data, i.e., log_prob_iid can be removed (it was not used anyway).
  • allow all kinds of flows for MNLE, not just nsf
  • unify the z-scoring API for build_categoricalmassestimator
  • fix embedding net handling for MNLE: it did not allow for theta embeddings before, now it does. Importantly, one has to use a "mixed embedding" for the conditioning of the flow because the condition contains the embedded theta and the "raw" discrete data.
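
A minimal sketch of the first point, with hypothetical shapes and a plain torch Categorical standing in for CategoricalNet: each iid trial is just another entry along the sample dimension, and the joint log prob of the iid trials is the sum over that dimension.

```python
import torch
from torch.distributions import Categorical

# Hypothetical stand-in for CategoricalNet.log_prob: evaluate inputs of
# shape (sample_dim, batch_dim) against per-batch-entry categorical logits.
def categorical_log_prob(x, logits):
    # logits: (batch_dim, num_categories); broadcasts over the sample dim.
    return Categorical(logits=logits).log_prob(x)

num_trials, batch_dim, num_categories = 10, 5, 3
logits = torch.randn(batch_dim, num_categories)  # one condition per batch entry
x_iid = torch.randint(0, num_categories, (num_trials, batch_dim))

per_trial = categorical_log_prob(x_iid, logits)  # (num_trials, batch_dim)
log_prob_iid = per_trial.sum(dim=0)              # (batch_dim,), no repetition tricks
```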

Does this close any currently open issues?

Fixes #1134
Fixes #1136
Fixes #1172

Any other comments?

The first commit avoids circular imports.

@janfb janfb added the enhancement New feature or request label Jul 24, 2024
@janfb janfb requested a review from michaeldeistler July 24, 2024 15:17
@janfb janfb self-assigned this Jul 24, 2024
codecov bot commented Jul 24, 2024

Codecov Report

Attention: Patch coverage is 98.07692% with 1 line in your changes missing coverage. Please review.

Project coverage is 75.97%. Comparing base (ba19688) to head (5aa7c2a).
Report is 10 commits behind head on main.

Files Patch % Lines
sbi/neural_nets/categorial.py 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1203      +/-   ##
==========================================
- Coverage   84.55%   75.97%   -8.59%     
==========================================
  Files          96       97       +1     
  Lines        7603     7668      +65     
==========================================
- Hits         6429     5826     -603     
- Misses       1174     1842     +668     
Flag Coverage Δ
unittests 75.97% <98.07%> (-8.59%) ⬇️

Flags with carried forward coverage won't be shown.

Files Coverage Δ
.../neural_nets/density_estimators/categorical_net.py 97.82% <100.00%> (-0.22%) ⬇️
...nets/density_estimators/mixed_density_estimator.py 98.11% <100.00%> (+28.54%) ⬆️
sbi/neural_nets/density_estimators/zuko_flow.py 64.44% <ø> (ø)
sbi/neural_nets/flow.py 93.12% <100.00%> (ø)
sbi/neural_nets/mnle.py 100.00% <100.00%> (ø)
sbi/utils/__init__.py 100.00% <ø> (ø)
sbi/neural_nets/categorial.py 94.73% <91.66%> (+4.73%) ⬆️

... and 40 files with indirect coverage changes

@janfb janfb force-pushed the refactor-mnle branch 2 times, most recently from 3f00930 to 00dbd51, on July 26, 2024 15:07
@michaeldeistler michaeldeistler (Contributor) left a comment

Thanks a ton, this is awesome!

For now, this can only handle 1D discrete dimensions, right?

In the long run, the ultimate solution would be to have an Autoregressive abstraction which just concatenates a list of Estimators (can be DensityEstimator or MassEstimator) in an autoregressive way. No need to do this now ofc.

@@ -101,6 +101,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
         Args:
             input: Inputs to evaluate the log probability on. Of shape
                 `(sample_dim, batch_dim, *event_shape)`.
+            # TODO: the docstring is not correct here. in the code it seems we
+            # do not have a sample_dim for the condition.
@michaeldeistler michaeldeistler (Contributor) replied:
we do not have a sample_dim for the condition, but we have a sample_dim for the input. I think the docstring is correct

@michaeldeistler michaeldeistler (Contributor) commented:
One more comment for the PartialEmbedding: I think we could also have the embedding_net as an attribute of MixedDensityEstimator instead of as an attribute of the DensityEstimator and the MassEstimator. Not fully sure if this would work, but it feels a bit wasteful to have the embedding_net twice.

@janfb janfb (Contributor, Author) commented Jul 30, 2024

> Thanks a ton, this is awesome!
>
> For now, this can only handle 1D discrete dimensions, right?
>
> In the long run, the ultimate solution would be to have an Autoregressive abstraction which just concatenates a list of Estimators (can be DensityEstimator or MassEstimator) in an autoregressive way. No need to do this now ofc.

Yes, the CategoricalNet can only handle 1D because it's not straightforward to extend. Maybe it would work better with @coschroeder's Grassmann distribution? Although that one is binary only.

Yes, the autoregressive approach of just concatenating conditionals and using one CategoricalNet for each dimension would be a really nice feature.
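
A rough sketch of what such an autoregressive abstraction could look like (hypothetical names, not part of sbi): one categorical net per discrete dimension, each conditioned on the context plus all preceding discrete dimensions.

```python
import torch
from torch import Tensor, nn

class AutoregressiveCategoricals(nn.Module):
    """Hypothetical sketch: nets[i] maps (condition, x[:, :i]) to logits for dim i."""

    def __init__(self, nets: list[nn.Module]) -> None:
        super().__init__()
        self.nets = nn.ModuleList(nets)

    def log_prob(self, x: Tensor, condition: Tensor) -> Tensor:
        # x: (batch, num_discrete_dims), float-coded categories;
        # condition: (batch, condition_dim).
        log_prob = torch.zeros(x.shape[0])
        for i, net in enumerate(self.nets):
            # Condition on the context and on the preceding discrete dims.
            logits = net(torch.cat([condition, x[:, :i]], dim=1))
            log_prob += torch.distributions.Categorical(logits=logits).log_prob(x[:, i])
        return log_prob
```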

@janfb janfb (Contributor, Author) commented Jul 30, 2024

> One more comment for the PartialEmbedding: I think we could also have the embedding_net as an attribute of MixedDensityEstimator instead of as an attribute of the DensityEstimator and the MassEstimator. Not fully sure if this would work, but it feels a bit wasteful to have the embedding_net twice.

I think both estimators need an embedding net because these are different embeddings. For the MassEstimator, it's the y / theta embedding that is stitched together with the optional standardizing net in build_categoricalmassestimator. For the DensityEstimator, it's the PartialEmbedding that contains both the discrete and the y / theta data.
But I am not sure whether that's what you mean?

Looking at this now, I think there is a problem: the density estimation build function, e.g., build_nsf, will build a standardizing net for the PartialEmbedding using the entire batch of concatenated discrete and continuous data. So effectively, it will z-score the discrete data as well (see the sketch below). E.g., when y is an image so that the y_batch passed to the embedding has shape (batch, 32, 32), including the expanded and repeated discrete x values, the discrete x values will influence the z-scoring of the image, right?
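
A toy illustration of that concern (standalone torch, hypothetical shapes): if the standardizing transform is fit on the full concatenated batch, the discrete column is z-scored along with everything else.

```python
import torch

# Hypothetical shapes: continuous features concatenated with a discrete
# column, as in the PartialEmbedding input described above.
continuous = torch.randn(100, 4)                  # e.g. flattened image features
discrete = torch.randint(0, 3, (100, 1)).float()  # categorical codes 0, 1, 2

batch = torch.cat([continuous, discrete], dim=1)

# A standardizing net fit on the full batch z-scores every column,
# including the discrete one: the categorical codes get shifted and rescaled.
mean, std = batch.mean(dim=0), batch.std(dim=0)
z_scored = (batch - mean) / std
print(z_scored[:, -1].unique())  # no longer the integer codes 0, 1, 2
```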

@janfb janfb (Contributor, Author) commented Aug 2, 2024

Update: the PartialEmbedding and the expanding and repeating of discrete data are no longer needed.

We now build the y-embedding inside build_mnle so that we can pass the concatenation of the embedded y and the discrete data into the continuous density estimator. There, we also pass a combined_embedding_net that combines the two conditions.

This also enables us to handle sample_shape > 1 by simply repeating the embedded continuous condition accordingly and concatenating it with the discrete input that has sample_shape > 1 (sketched below).
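
A rough sketch of that flow (hypothetical module choices; build_mnle wires up the real counterparts): embed y once, repeat the embedded condition along the sample dimension, and concatenate it with the discrete input before conditioning the flow.

```python
import torch
from torch import nn

# Hypothetical components standing in for what build_mnle assembles.
embedding_net = nn.Linear(32 * 32, 10)  # y-embedding, e.g. for image-shaped y
combined_embedding_net = nn.Identity()  # combines embedded y and discrete x

y = torch.randn(5, 32 * 32)                          # (batch_dim, *y_shape)
x_discrete = torch.randint(0, 3, (7, 5, 1)).float()  # (sample_dim, batch_dim, 1)

embedded_y = embedding_net(y)  # (batch_dim, 10), computed only once
# For sample_shape > 1, repeat the embedded condition along the sample dim ...
embedded_y = embedded_y.unsqueeze(0).expand(x_discrete.shape[0], -1, -1)
# ... and concatenate it with the discrete input to condition the flow.
condition = combined_embedding_net(torch.cat([embedded_y, x_discrete], dim=-1))
print(condition.shape)  # torch.Size([7, 5, 11])
```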

@janfb janfb added this to the Hackathon and release 2024 milestone Aug 2, 2024
@michaeldeistler michaeldeistler (Contributor) left a comment

This is awesome!!!!!

A few comments below, but good to go then!

sbi/neural_nets/embedding_nets.py (outdated, resolved)
sbi/neural_nets/mnle.py (outdated, resolved)
sbi/utils/nn_utils.py (outdated, resolved)
@janfb janfb merged commit f9ec0bd into main Aug 2, 2024
6 checks passed
@janfb janfb deleted the refactor-mnle branch August 2, 2024 15:57