-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Analytic sampling for conditional posterior instances trained with MDNs. #458
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! I made a few preliminary comments. Nothing for you to do yet. Instead, we should re-write the MDN class in pyknos
. This will avoid code repetition of sample
and log_prob
methods. I will take care of re-writing pyknos and will let you know. This might take a few days though.
You can track progress here |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okey, the PR is merged. We can now use .sample_mog()
and .log_prob_mog()
and replace basically the entire probability evaluations and sample functions by just calling these methods. Let me know in case I am missing something, and thanks again for taking over this endeavor!
So much for "This might take a few days though." 😄 |
sbi/utils/conditional_density.py
Outdated
log_factor = torch.log(self.leakage_correction(x=self.default_x)) | ||
return torch.log(torch.sum(pdf, axis=1)) - log_factor | ||
self.net.eval() # leakage correction requires eval mode | ||
log_factor = torch.log(self.leakage_correction(x=self.default_x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the leakage correction does not work for the conditioned MDNs, because it uses samples from self
i.e. the full posterior rather then from the new, conditioned posterior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extract_and_transform_mog
+ __conditionalise
can now be used to get new sets of mog logits, means etc. for an arbitrary condition which makes them compatible with sample_mog
and log_prob_mog
.
Next step: Introduce extract_and_transform_mog
+ __conditionalise
into DirectPosterior
.
Then insert them into sample_conditional
together with sample_mog
and sample_posterior_within_prior
as well as into
log_prob_conditional
together with log_prob_mog
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have integrated everything into DirectPosterior as discussed and according to my tests it works.
Also... I just commited a bunch of files for review on accident. The only thing important for merging is the direct_posterior.py.
Looking forward to your remarks. Hope this round of review is the last :)
Codecov Report
@@ Coverage Diff @@
## main #458 +/- ##
==========================================
- Coverage 67.96% 66.96% -1.00%
==========================================
Files 56 56
Lines 4117 4190 +73
==========================================
+ Hits 2798 2806 +8
- Misses 1319 1384 +65
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code should be ready for final review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, I finally got round to implementing everything. I hope it still plays nicely with the new version of sbi that has come out in the meantime. I have also added something resembling a unit test, but I am not sure how to integrate it properly. Help would be greatly appreciated. I have just plugged it into an extra file for now. Looking forward to your comments.
Best,
Jonas
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! Super minor things only, we can merge it after this!
I have just addressed your latest comments. Hope that it can be merged now. :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great and will be super useful, thanks a lot for all your efforts!
Can be merged once merge conflicts are resolved. |
I just resolved all conflicts, if it passes the tests, I guess you can merge it :) |
Hey,
I added code to analytically sample and condition
DirectPosterior
instances trained on MDNs. Wrapping aDirectPosterior
instance that has been trained using an MDN in this way, replaces the.log_prob()
and.sample()
methods with analytical ones. This means it should be compatible with the rest of sbi.Example
After training the posterior...
...it can be wrapped, ...
...and conditioned.
Variables of interest are represented by
'nan'
s.Then
cond_posterior
can be sampled and evaluated as before.I am sure this could somehow be integrated into the
DirectPosterior
class directly as well, since theMDNPosterior
inherits from it.Let me know of any issues, I'd be happy to help integrating it into sbi. :)