-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
restriction estimator and restricted prior now picklable #976
Conversation
Hi @danielmk, thanks a lot for the PR and the detailed description of the changes! I assigned @michaeldeistler for the review, as he is more on top of the code for the restriction estimator. Thanks also for the pointer about the outdated contribution guidelines, we recently changed the project setup and will fix this soon (see #945). |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #976 +/- ##
==========================================
+ Coverage 76.17% 76.19% +0.02%
==========================================
Files 83 83
Lines 6406 6420 +14
==========================================
+ Hits 4880 4892 +12
- Misses 1526 1528 +2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi there, thanks a lot! Please run also the tutorial https://github.com/sbi-dev/sbi/blob/main/tutorials/08_restriction_estimator.ipynb
and make sure that everything works. In addition, could you run pytest tests/inference_with_NaN_simulator_test.py`?
Beyond that, almost all looks good, I only left a small comment below.
Thanks a ton for this PR!
Michael
predictions = probs_valid > probs_invalid | ||
return predictions.bool() | ||
|
||
if print_fp_rate: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did this get removed? self._print_fp_rate
seems to be unused now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deleting if print_fp_rate: was a mistake on my end. I have now included print_false_positive_rate
as a method of AcceptRejectFunction
.
Now, if self._print_fp_rate
, the function is called. The refactoring is a bit awkward on this one, technically the method does not need any input parameters other than self but I opted to keep the input parameters so they remain explicit during the call and for consistency with the previous implementation. I also checked that setting it to true actually prints the fp rate. It’s also worth nothing that there seems to be no good way right now for a user to print the fp rate but since that’s unrelated to pickling I won’t go into detail here.
Fixing commit is below.
I ran the tutorial as requested and it does not reproduce the exact sampling from here (I guess torch does not allow inter-hardware reproducibility?) but overall is seems to work. Here is some of the printout:
Sampling from the restricted prior then does not produce invalids:
Overall the resulting SBI looks plausible: Finally I ran
There are failures but actually the main branch test summary looks identical:
Let me know if any of this requires fixing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @danielmk!
Thanks for all the updates, and sorry for the delay in my response (I was on vacation). All looks good from my side!
Don't worry about the sampling not being bitwise identical.
Regaring failing tests: xfailed
means that the test is supposed to fail and it indeed fails, so all is good here as well.
@danielmk unfortunately, tests are failing because of
I will merge the PR once this is fixed :) Thanks a ton! |
I ran black and isort and pushed. You're welcome, I think sbi is very important for computational neuroscience so I am happy to contribute. |
This is the PR for #975. Both the restriction estimator and the restricted prior are now picklable. Both of them had separate issues relating to function definitions within functions that prevent pickle from serializing.
The refactoring I did to make them picklable were slightly bigger than initially expected. Please let me know if you disagree with any of the refactorings.
The main issue of the restriction estimator was in
build_classifier
, where abuild_nn
function was defined based on parameters. I found that thebuild_classifier
function is not necessary, since all parameters come from theRestrictionEstimator.__init__
anyway. I now just decide in the__init__
what the instancesself._build_nn
should be.The main problem of the restricted prior was in
get_classifier_thresholder()
which defines a function inside and returns it as a callable. To remedy the pickling issue I replaced theget_classifier_thresholder()
function with a class calledAcceptRejectFunction
that acts like theaccept_reject_fn
through the__call__(theta)
method. Instances of this class thereby act like theaccept_reject_fn
but are parameterized through the__init__
method instead of a surroundingget_classifier_thresholder()
function.I ran the script in #975 with the refactoring to make sure that inference is still reproducible and the pickling now works. I ran
black
andisort
but pyright gave a large number of errors. I did not run tests, since it is not mentioned in the contributing guidelines (as a sidenote, the contributing guidelines might be out of date, as I found no environment.yml, so I didpip install -e .
in a fresh environment instead, python=3.9). I would do documentation if you approve of these refactorings.