Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MAF with RQS as density estimator #819

Merged
merged 6 commits into from
Mar 15, 2023
Merged

add MAF with RQS as density estimator #819

merged 6 commits into from
Mar 15, 2023

Conversation

ImahnShekhzadeh
Copy link
Contributor

Hi,

for SNPE, these were the options: mdn, made, maf and nsf (cf. the function posterior_nn in sbi/utils/get_nn_models.py).

The MAF uses affine-linear diffeomorphisms (and no RQS), the NSF uses coupling flows with RQS. I added MAFs with RQS (which is the new option maf_rqs).

About black formatting: I'm unsure what the maximum line length is in this project, so I haven't black formatted my code yet.

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this looks great! I think we can also use these for SNLE, see comment below. Also, regarding tests, could you add (SNPE_C, "maf_rqs", "direct") and (SNLE, "maf_rqs", "slice"), here?

Regarding black: we use a line length of 88 (default) with the latest black version. The black test is currently failing.

elif model == "maf_rqs":
return build_maf_rqs(
batch_x=batch_theta, batch_y=batch_x, **kwargs
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to not add this to likelihood_nn as well?

@janfb
Copy link
Contributor

janfb commented Mar 13, 2023

Great, thanks for adding this @ImahnShekhzadeh !

We should add maf_rqs to the functional tests here. However, I noticed that this is not straight forward at the moment.
I made a PR to change this, #820 . Please rebase on main once that PR is merged and then add maf_rqs to the list of tests.

@janfb
Copy link
Contributor

janfb commented Mar 13, 2023

#820 is merged now.
@ImahnShekhzadeh please rebase your branch on the most recent main and then add maf_rqs to the list of tested density estimator here:
https://github.com/mackelab/sbi/blob/main/tests/linearGaussian_snpe_test.py#L149

The test is marked as slow such that it will not be executed as part of our CI. Thus, please run the test locally, e.g., using

pytest -s tests/linearGaussian_snpe_test.py::test_density_estimators_on_linearGaussian

I tested this on your branch and got the following error:

inputs = tensor([[ 0.8851,  0.4588,  0.6310,  0.8696],
        [-0.8224, -0.3964, -0.8029,  1.5897]])
unnormalized_widths = tensor([[[-0.0842, -0.0244, -0.0431, -0.0043, -0.0082, -0.0356,  0.0360,
           0.0119,  0.0080,  0.0542],
       ...-0.0170,  0.0520, -0.0581, -0.0040, -0.0741, -0.0116,
           0.0008,  0.0562, -0.1210]]], grad_fn=<SliceBackward0>)
unnormalized_heights = tensor([[[ 2.6453e-02,  5.3341e-02,  7.3937e-02, -3.5198e-02,  5.7186e-02,
          -8.8525e-02, -5.0385e-02, -2.8635...501e-01,
           2.0109e-02,  8.8173e-03, -7.4991e-02,  8.3034e-02,  2.1044e-02]]],
       grad_fn=<SliceBackward0>)
unnormalized_derivatives = tensor([[[ 0.0386, -0.0607,  0.1006, -0.1117,  0.1289, -0.0991, -0.1373,
          -0.0095,  0.1147,  0.1365, -0.1023]...-0.1592,  0.1385,  0.0161, -0.0628, -0.0281,
           0.1999,  0.0236, -0.0915, -0.1208]]], grad_fn=<SliceBackward0>)
inverse = False, left = 0.0, right = 1.0, bottom = 0.0, top = 1.0, min_bin_width = 0.001, min_bin_height = 0.001, min_derivative = 0.001

    def rational_quadratic_spline(
        inputs,
        unnormalized_widths,
        unnormalized_heights,
        unnormalized_derivatives,
        inverse=False,
        left=0.0,
        right=1.0,
        bottom=0.0,
        top=1.0,
        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
        min_derivative=DEFAULT_MIN_DERIVATIVE,
    ):
        if torch.min(inputs) < left or torch.max(inputs) > right:
>           raise InputOutsideDomain()
E           nflows.transforms.base.InputOutsideDomain

../../opt/anaconda3/envs/sbi/lib/python3.8/site-packages/nflows/transforms/splines/rational_quadratic.py:79: InputOutsideDomain

@ImahnShekhzadeh
Copy link
Contributor Author

ImahnShekhzadeh commented Mar 13, 2023

@michaeldeistler I added the MAF with RQS to SNLE, added the test statements and I hope the black test runs through now.

@janfb I added maf_rqs to the list, but this results in a conflict, which should be quick to resolve. It's expected that the test doesn't run through, since by default, the RQS use no tails (i.e. are defined on [-B, B] only) and the tail bound (B) is by default set to 1, yet one of the input elements is greater than 0. I verified this by setting (only locally) tails = 'linear'. Afterwards, I faced another problem (relating to the classifier score being too far away from 0.5), which I solved by increasing the tail_bound to 3 (again only locally), and the test runs through (which makes sense to me given that the prior is a multivariate normal distribution).

Please let me know if anything else should be necessary for a merging.

@michaeldeistler
Copy link
Contributor

Any reason to not use tails = 'linear' and tail_bound=3 as defaults? (same as nsf)

@ImahnShekhzadeh
Copy link
Contributor Author

@michaeldeistler I had used tails=None and tail_bound=1, since these are the default values in the nflows package. But I agree, to make this more consistent with nsf, I adjusted the default values to tails='linear' and tail_bound=3.

@michaeldeistler
Copy link
Contributor

Awesome, thank you!

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the update!
The functional test with SNPE is still failing with c2st ~ 0.6. I suggest to increase num_simulations=2100

any ideas why maf_rqs need more training data than the other flows?

@ImahnShekhzadeh
Copy link
Contributor Author

ImahnShekhzadeh commented Mar 14, 2023

@janfb I just ran the test for three times, it ran through for me in all cases. The c2st for snpe_maf_rqs (with 2000 simulations) was 0.55 in all three cases. Would you still like me increase the number of simulations to 2100?

@janfb
Copy link
Contributor

janfb commented Mar 14, 2023

Oh, ok. That's probably due to random init and different RNG settings on our machines. So, no, let's keep it at 2000 then.

@janfb
Copy link
Contributor

janfb commented Mar 14, 2023

One last thing: could you please run black v23.1.0 and isort v5.11.5:
black .
and
isort .

locally and then push again?

@ImahnShekhzadeh
Copy link
Contributor Author

@janfb Sure, it's done.

@janfb janfb self-requested a review March 15, 2023 07:15
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be merged once CI tests are passing.

@codecov-commenter
Copy link

Codecov Report

Merging #819 (21f6180) into main (31c6076) will decrease coverage by 0.25%.
The diff coverage is 24.13%.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

@@            Coverage Diff             @@
##             main     #819      +/-   ##
==========================================
- Coverage   74.83%   74.59%   -0.25%     
==========================================
  Files          80       80              
  Lines        6196     6222      +26     
==========================================
+ Hits         4637     4641       +4     
- Misses       1559     1581      +22     
Flag Coverage Δ
unittests 74.59% <24.13%> (-0.25%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
sbi/neural_nets/flow.py 80.37% <9.09%> (-18.45%) ⬇️
sbi/utils/get_nn_models.py 89.09% <71.42%> (-3.07%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@janfb janfb changed the title For SNPE, MAF with RQS is available now. add MAF with RQS as density estimator Mar 15, 2023
@janfb janfb merged commit 7a26b70 into sbi-dev:main Mar 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants