Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mnle as model class to snle #638

Merged
merged 9 commits into from
Feb 17, 2022
Merged

add mnle as model class to snle #638

merged 9 commits into from
Feb 17, 2022

Conversation

janfb
Copy link
Contributor

@janfb janfb commented Feb 7, 2022

this provides easy access to Mixed Neural Likelihood Estimation (MNLE) in sbi: MNLE is a child of SNLE and can be used by calling MNLE() and otherwise using the same API as with SNLE:

from sbi.inference import MNLE
trainer = MNLE(prior)
estimator = trainer.append_simulations(theta, x).train()
posterior = trainer.build_posterior()

It is implemented as a single density estimator class that has two separate nets for discrete and continuous data, but trains them with a single loss call, i.e., using the usual SNLE train method (thanks to @michaeldeistler for the idea).

It requires a separate mixed_likelihood_estimator_based_potential because it implements its own log_prob for iid-x to get some speed up during MCMC.

Tasks:

  • finish docstrings and comments
  • add interface to potential function for efficient evaluation with iid-trial data
  • add tutorial notebook for inference in decision-making models.

@codecov-commenter
Copy link

codecov-commenter commented Feb 7, 2022

Codecov Report

Merging #638 (8e464b3) into main (78cc11e) will increase coverage by 1.71%.
The diff coverage is 89.34%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #638      +/-   ##
==========================================
+ Coverage   74.81%   76.53%   +1.71%     
==========================================
  Files          75       77       +2     
  Lines        5667     5826     +159     
==========================================
+ Hits         4240     4459     +219     
+ Misses       1427     1367      -60     
Flag Coverage Δ
unittests 76.53% <89.34%> (+1.71%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
sbi/inference/potentials/__init__.py 100.00% <ø> (ø)
sbi/inference/snle/snle_a.py 100.00% <ø> (ø)
sbi/utils/user_input_checks_utils.py 90.25% <80.00%> (-1.80%) ⬇️
sbi/inference/snle/mnle.py 85.71% <85.71%> (ø)
sbi/neural_nets/mnle.py 88.11% <88.11%> (ø)
sbi/inference/__init__.py 100.00% <100.00%> (ø)
sbi/inference/posteriors/mcmc_posterior.py 74.82% <100.00%> (+1.58%) ⬆️
...inference/potentials/likelihood_based_potential.py 100.00% <100.00%> (ø)
sbi/inference/snle/__init__.py 100.00% <100.00%> (ø)
sbi/inference/snle/snle_base.py 91.83% <100.00%> (-0.17%) ⬇️
... and 12 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 78cc11e...8e464b3. Read the comment docs.

Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the implementation a lot, thanks! I left some comments below

sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/neural_nets/mnle.py Outdated Show resolved Hide resolved
sbi/neural_nets/mnle.py Outdated Show resolved Hide resolved
tests/mnle_test.py Outdated Show resolved Hide resolved
tests/mnle_test.py Outdated Show resolved Hide resolved
sbi/neural_nets/mnle.py Show resolved Hide resolved
Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! Only minor comments below. Good to go once they are addressed.

sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/mnle.py Outdated Show resolved Hide resolved
sbi/inference/snle/snle_a.py Show resolved Hide resolved
sbi/neural_nets/classifier.py Show resolved Hide resolved
sbi/neural_nets/mnle.py Show resolved Hide resolved
tests/mnle_test.py Show resolved Hide resolved
tests/mnle_test.py Outdated Show resolved Hide resolved
tests/mnle_test.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants