restriction estimator and restricted prior now picklable #976

Merged · 3 commits merged into sbi-dev:main on Mar 13, 2024

Conversation

@danielmk (Contributor) commented on Mar 5, 2024

This is the PR for #975. Both the restriction estimator and the restricted prior are now picklable. Each had a separate issue relating to function definitions inside functions, which prevent pickle from serializing them.
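(For context, a minimal toy example, not sbi code, of the underlying problem: pickle serializes functions by reference to a module-level name, so a function defined inside another function cannot be pickled.)

```python
import pickle

def make_thresholder(threshold):
    # Locally defined function: pickle looks it up by qualified name and
    # cannot serialize 'make_thresholder.<locals>.accept'.
    def accept(value):
        return value > threshold
    return accept

accept_fn = make_thresholder(0.5)
try:
    pickle.dumps(accept_fn)
except (pickle.PicklingError, AttributeError) as err:
    print(f"Pickling failed: {err}")
```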

The refactorings I did to make them picklable are slightly bigger than initially expected. Please let me know if you disagree with any of them.

The main issue in the restriction estimator was in build_classifier, where a build_nn function was defined based on parameters. I found that the build_classifier function is not necessary, since all of its parameters come from RestrictionEstimator.__init__ anyway. I now decide directly in __init__ what the instance attribute self._build_nn should be.
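Roughly, that pattern looks as follows (an illustrative sketch only; the builder name, hidden_features, and the torch layers are my assumptions, not the actual sbi code). A functools.partial of a module-level builder is picklable, whereas a nested build_nn closure is not:

```python
from functools import partial

from torch import nn


def build_mlp_classifier(hidden_features: int) -> nn.Module:
    # Module-level builder: picklable by reference, unlike a nested build_nn.
    return nn.Sequential(
        nn.LazyLinear(hidden_features),
        nn.ReLU(),
        nn.Linear(hidden_features, 2),
    )


class EstimatorSketch:
    def __init__(self, hidden_features: int = 100) -> None:
        # Decide in __init__ what self._build_nn should be, instead of
        # defining a closure inside a separate build_classifier helper.
        self._build_nn = partial(build_mlp_classifier, hidden_features=hidden_features)
```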

The main problem in the restricted prior was get_classifier_thresholder(), which defines a function inside itself and returns it as a callable. To remedy the pickling issue I replaced get_classifier_thresholder() with a class called AcceptRejectFunction that acts like accept_reject_fn through its __call__(theta) method. Instances of this class therefore behave like accept_reject_fn but are parameterized through __init__ instead of through an enclosing get_classifier_thresholder() function.
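The closure-to-class conversion follows this shape (a sketch under assumed attribute names and classifier column ordering, not the merged sbi implementation):

```python
import torch
from torch import nn


class AcceptRejectFunctionSketch:
    """Callable object standing in for the nested accept_reject_fn closure."""

    def __init__(self, classifier: nn.Module) -> None:
        # State the old closure captured now lives on the instance, so
        # pickling the instance carries everything the callable needs.
        self._classifier = classifier

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        # Accept a parameter set when the classifier favours "valid"
        # (the column ordering here is assumed for illustration).
        logits = self._classifier(theta)
        probs_valid, probs_invalid = logits[:, 0], logits[:, 1]
        predictions = probs_valid > probs_invalid
        return predictions.bool()
```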

I ran the script in #975 with the refactoring to make sure that inference is still reproducible and that pickling now works. I ran black and isort, but pyright gave a large number of errors. I did not run the tests, since that is not mentioned in the contributing guidelines (as a side note, the contributing guidelines might be out of date: I found no environment.yml, so I did pip install -e . in a fresh environment instead, python=3.9). I would add documentation if you approve of these refactorings.
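(As a side note, a small round-trip helper like the one below is one way to check picklability; this is a generic sketch, not part of the PR.)

```python
import io
import pickle


def pickle_roundtrip(obj):
    """Serialize and deserialize an object; raises if pickling fails."""
    buffer = io.BytesIO()
    pickle.dump(obj, buffer)
    buffer.seek(0)
    clone = pickle.load(buffer)
    assert type(clone) is type(obj)
    return clone
```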

@janfb (Contributor) commented on Mar 5, 2024

Hi @danielmk, thanks a lot for the PR and the detailed description of the changes!

I assigned @michaeldeistler for the review, as he is more on top of the code for the restriction estimator.

Thanks also for the pointer about the outdated contribution guidelines; we recently changed the project setup and will fix this soon (see #945).

codecov bot commented on Mar 5, 2024

Codecov Report

Attention: Patch coverage is 62.68657%, with 25 lines in your changes missing coverage. Please review.

Project coverage is 76.19%. Comparing base (17f3033) to head (b9e0871).
Report is 22 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| sbi/utils/restriction_estimator.py | 62.68% | 25 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #976      +/-   ##
==========================================
+ Coverage   76.17%   76.19%   +0.02%     
==========================================
  Files          83       83              
  Lines        6406     6420      +14     
==========================================
+ Hits         4880     4892      +12     
- Misses       1526     1528       +2     
| Flag | Coverage Δ |
|---|---|
| unittests | 76.19% <62.68%> (+0.02%) ⬆️ |

Flags with carried forward coverage won't be shown.


@michaeldeistler (Contributor) left a comment

Hi there, thanks a lot! Please also run the tutorial https://github.com/sbi-dev/sbi/blob/main/tutorials/08_restriction_estimator.ipynb and make sure that everything works. In addition, could you run `pytest tests/inference_with_NaN_simulator_test.py`?

Beyond that, almost everything looks good; I only left one small comment below.

Thanks a ton for this PR!
Michael

predictions = probs_valid > probs_invalid
return predictions.bool()

if print_fp_rate:
@michaeldeistler (Contributor) commented:

Why did this get removed? self._print_fp_rate seems to be unused now.

@danielmk (Contributor Author) replied:

Deleting if print_fp_rate: was a mistake on my end. I have now included print_false_positive_rate as a method of AcceptRejectFunction, and it is called when self._print_fp_rate is set. The refactoring is a bit awkward here: technically the method needs no input parameters other than self, but I kept the parameters so they remain explicit at the call site and for consistency with the previous implementation. I also checked that setting the flag to True actually prints the false-positive rate. It is also worth noting that there currently seems to be no good way for a user to print the false-positive rate, but since that is unrelated to pickling I won't go into detail here.
The fixing commit is below.
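Roughly, the method hook sits on the class like this (names, signature, and the printed quantity are assumptions for illustration; the real method reports the classifier's false-positive rate):

```python
import torch
from torch import nn


class AcceptRejectWithPrinting:
    def __init__(self, classifier: nn.Module, print_fp_rate: bool = False) -> None:
        self._classifier = classifier
        self._print_fp_rate = print_fp_rate

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        logits = self._classifier(theta)
        predictions = (logits[:, 0] > logits[:, 1]).bool()
        if self._print_fp_rate:
            # The arguments could be read from self, but passing them keeps
            # the call explicit, mirroring the previous implementation.
            self.print_false_positive_rate(predictions)
        return predictions

    def print_false_positive_rate(self, predictions: torch.Tensor) -> None:
        # Stand-in body for illustration: prints the rejection fraction;
        # the real method estimates the classifier's false-positive rate.
        rejected = 1.0 - predictions.float().mean().item()
        print(f"Fraction of rejected parameter sets: {rejected:.2%}")
```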

@danielmk (Contributor Author) commented on Mar 6, 2024

I ran the tutorial as requested. It does not reproduce the exact samples shown in the tutorial (I guess torch does not guarantee bitwise reproducibility across hardware?), but overall it seems to work.

Here is some of the printout:

Simulation outputs:  tensor([[ 0.0538, -0.1295],
        [ 0.7811, -0.1608],
        [ 0.8663,  0.3622],
        ...,
        [    nan,     nan],
        [    nan,     nan],
        [ 1.7638,  0.1825]])
The `RestrictedPrior` rejected 49.5% of prior samples. You will get a speed-up of 98.0%.

Sampling from the restricted prior then does not produce invalids:

Simulation outputs:  tensor([[ 1.0857,  0.2360],
        [ 0.8034,  1.5441],
        [ 0.6469,  1.9206],
        ...,
        [ 0.2529, -1.0084],
        [ 0.6547,  1.6762],
        [ 1.8734, -1.3236]])

Overall the resulting SBI looks plausible:
[posterior plot attached in the original comment]

Finally, I ran `pytest tests/inference_with_NaN_simulator_test.py`; the short test summary:

XFAIL tests/inference_with_NaN_simulator_test.py::test_inference_with_nan_simulator[SNLE_A-0.05]
XFAIL tests/inference_with_NaN_simulator_test.py::test_inference_with_nan_simulator[SNRE_B-0.05]
8 passed, 2 xfailed, 6 warnings in 138.62s (0:02:18)

There are xfailed tests, but the main branch test summary looks identical:

XFAIL tests/inference_with_NaN_simulator_test.py::test_inference_with_nan_simulator[SNLE_A-0.05]
XFAIL tests/inference_with_NaN_simulator_test.py::test_inference_with_nan_simulator[SNRE_B-0.05]
8 passed, 2 xfailed, 6 warnings in 141.36s (0:02:21)

Let me know if any of this requires fixing.

@michaeldeistler (Contributor) left a comment

Hi @danielmk!

Thanks for all the updates, and sorry for the delay in my response (I was on vacation). All looks good from my side!

Don't worry about the sampling not being bitwise identical.

Regarding the failing tests: xfailed means that the test is expected to fail and it indeed fails, so all is good here as well.

@michaeldeistler (Contributor) commented:

@danielmk unfortunately, tests are failing because of black, so please run

black sbi
black tests
isort sbi
isort tests

I will merge the PR once this is fixed :) Thanks a ton!

@danielmk (Contributor Author) commented:
I ran black and isort and pushed. You're welcome; I think sbi is very important for computational neuroscience, so I am happy to contribute.

@michaeldeistler merged commit 3c1bb5a into sbi-dev:main on Mar 13, 2024 (2 of 3 checks passed).