
Allow prior on gpu #519

Merged: 5 commits merged into main, Jul 8, 2021
Conversation

@janfb (Contributor) commented Jul 5, 2021

Relates to #515

Up to now we assumed the prior lived on the CPU, and we moved samples to .cpu() whenever combining log_probs from the prior and the net.

This PR allows the prior to live on the GPU, and it asserts that the prior lives on the same device as the device passed for training.

Pro: we don't need to move to .cpu() all the time.
Con:

  • All the numpy-based MCMC methods naturally run on the CPU. When evaluating theta under the prior, we now have to move it to the prior's device (instead of doing net.log_prob(theta).cpu()).
  • The user now gets an AssertionError when the prior was not initialised on the device. This was not the case before. Alternatively, we could introduce or deduce a prior_device and take care of moving things around internally. Any opinions on that?

I haven't profiled it, but I think this approach is faster than the old one because we move things around less. And if we one day implement vectorized MCMC in torch, we might get speed-ups when running it on the GPU.

closes #515
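The device check described above could look roughly like the following sketch. The helper name `check_prior_device` is illustrative, not the actual sbi API; the idea is just to compare the device of a prior sample against the training device and fail loudly on a mismatch.

```python
# Hypothetical sketch: assert that samples drawn from the prior live on
# the device chosen for training. Not the merged sbi implementation.
import torch
from torch.distributions import MultivariateNormal


def check_prior_device(prior, device: str) -> None:
    """Raise if the prior's samples do not live on the training device."""
    sample = prior.sample((1,))  # device of the prior's parameters
    prior_device = sample.device
    training_device = torch.device(device)
    assert prior_device.type == training_device.type, (
        f"Prior is on {prior_device}, but training runs on {training_device}. "
        "Initialise the prior on the training device."
    )


prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
check_prior_device(prior, "cpu")  # passes: the prior's tensors are on the CPU
```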

@janfb janfb self-assigned this Jul 5, 2021
@michaeldeistler (Contributor):
This is great, thanks! It'll also be useful e.g. when the prior is a previous posterior.

I do not have a strong opinion on the assertion -- I have a slight preference for automatically moving the prior and giving a warning though. This is also what we do if the simulations lie on the GPU, so it feels natural to do the same thing for the prior. As I said, no strong opinions though. Feel free to merge if you think assert is better.

@janfb (Contributor, Author) commented Jul 6, 2021

Thanks @michaeldeistler!

Regarding moving the prior to the GPU: I haven't actually found a way to do this. One would have to create a new prior instance with the parameters living on the GPU, and for that one would basically need a long if-else construct to check all possible prior types...
So what I meant was to allow the case where the prior lives on the CPU but training happens on the GPU, and to then fall back to the old solution of moving everything to the CPU whenever the prior is involved (and maybe issue a warning that it might be better to pass a prior on the GPU). The other case, prior on GPU and training on CPU, doesn't make much sense and would throw an error.
Would you agree?
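The fallback logic proposed here could be sketched as follows. This is illustrative, not the sbi implementation; `resolve_prior_device` is a hypothetical name:

```python
# Sketch of the proposed fallback: prior on CPU with training on GPU
# falls back to CPU evaluation with a warning; prior on GPU with
# training on CPU is treated as an error.
import warnings


def resolve_prior_device(prior_device: str, training_device: str) -> str:
    """Return the device on which prior evaluations should happen."""
    if prior_device == training_device:
        return training_device
    if prior_device == "cpu":  # training on GPU, prior left on CPU
        warnings.warn(
            "Prior is on the CPU but training runs on the GPU; prior "
            "evaluations will move data to the CPU. Consider "
            "initialising the prior on the GPU."
        )
        return "cpu"
    # Prior on GPU while training on CPU makes no sense.
    raise ValueError(
        f"Prior on {prior_device} but training on {training_device}."
    )
```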

@michaeldeistler (Contributor):
That makes sense.

I think my preferred way would then be to stick to the current implementation using assert. However, we should add the device argument to BoxUniform, otherwise the error message will be confusing for users.
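A BoxUniform with a device argument could be sketched like this. sbi's BoxUniform is an Independent Uniform over a box; the snippet below mirrors that design but is illustrative rather than the merged code:

```python
# Sketch of a BoxUniform accepting a `device` argument, as suggested
# above. Placing low/high on `device` makes sample() and log_prob()
# return tensors on that device.
import torch
from torch.distributions import Independent, Uniform


class BoxUniform(Independent):
    def __init__(self, low, high, device: str = "cpu"):
        low = torch.as_tensor(low, dtype=torch.float32, device=device)
        high = torch.as_tensor(high, dtype=torch.float32, device=device)
        # reinterpret the event dims so log_prob sums over dimensions
        super().__init__(Uniform(low, high), 1)


prior = BoxUniform(low=[-1.0, -1.0], high=[1.0, 1.0], device="cpu")
theta = prior.sample((10,))
assert theta.device.type == "cpu" and theta.shape == (10, 2)
```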

codecov-commenter commented Jul 6, 2021

Codecov Report

Merging #519 (737e4cd) into main (df147b1) will increase coverage by 0.08%.
The diff coverage is 82.92%.

❗ Current head 737e4cd differs from pull request most recent head 6e73531. Consider uploading reports for the commit 6e73531 to get more accurate results.

@@            Coverage Diff             @@
##             main     #519      +/-   ##
==========================================
+ Coverage   67.70%   67.79%   +0.08%     
==========================================
  Files          55       55              
  Lines        3970     3965       -5     
==========================================
  Hits         2688     2688              
+ Misses       1282     1277       -5     
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 67.79% <82.92%> (+0.08%) | ⬆️ |

| Impacted Files | Coverage Δ | |
|---|---|---|
| sbi/analysis/conditional_density.py | 75.00% <ø> (ø) | |
| sbi/utils/sbiutils.py | 71.09% <33.33%> (ø) | |
| ...inference/posteriors/likelihood_based_posterior.py | 72.22% <60.00%> (-0.76%) | ⬇️ |
| sbi/inference/posteriors/base_posterior.py | 66.00% <75.00%> (ø) | |
| sbi/inference/posteriors/ratio_based_posterior.py | 78.82% <75.00%> (+1.55%) | ⬆️ |
| sbi/inference/posteriors/direct_posterior.py | 79.61% <87.50%> (+1.83%) | ⬆️ |
| sbi/inference/base.py | 78.41% <100.00%> (ø) | |
| sbi/inference/snpe/snpe_a.py | 66.50% <100.00%> (ø) | |
| sbi/mcmc/slice.py | 98.57% <100.00%> (ø) | |
| sbi/utils/metrics.py | 36.90% <100.00%> (+0.76%) | ⬆️ |
| ... and 2 more | | |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update df147b1...6e73531.

@janfb (Contributor, Author) commented Jul 6, 2021

@michaeldeistler I refactored the potential function to handle the devices efficiently. Do you still approve?
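The device-aware potential evaluation discussed here might look roughly like the sketch below (e.g. for a likelihood-based posterior, where the potential combines the net's and the prior's log-probs). This is an illustrative reconstruction, not the merged sbi code; `posterior_potential` and the stand-in net are hypothetical:

```python
# Illustrative sketch: evaluate the net on the training device and the
# prior on its own device, then combine the log-probs on the training
# device. Not the actual sbi implementation.
import torch
from torch.distributions import MultivariateNormal


def posterior_potential(net_log_prob, prior, theta, training_device="cpu"):
    """Unnormalized log posterior: net log_prob + prior log_prob."""
    # The net is evaluated on the training device ...
    net_lp = net_log_prob(theta.to(training_device))
    # ... while theta is moved to the prior's device for the prior term.
    prior_device = prior.sample((1,)).device
    prior_lp = prior.log_prob(theta.to(prior_device))
    # Combine both terms on the training device.
    return net_lp + prior_lp.to(training_device)


prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
fake_net = lambda t: torch.zeros(t.shape[0])  # stand-in for a trained net
pot = posterior_potential(fake_net, prior, torch.randn(5, 2))
assert pot.shape == (5,)
```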

@michaeldeistler (Contributor) left a comment:

It looks good, thanks. I made two comments. I think replacing get_potential with posterior_potential would make sense, no?

Review comment on sbi/inference/posteriors/direct_posterior.py (outdated, resolved)
@janfb janfb force-pushed the prior-on-gpu branch 3 times, most recently from 1989fbd to 5fa0770 Compare July 8, 2021 06:25
@janfb
Copy link
Contributor Author

janfb commented Jul 8, 2021

Slow tests are passing now. I had to fix a few small things; importantly, I had to remove the HMC tests with a uniform prior because they were taking forever. I hope we can fix that by moving to unconstrained space via #510.

@janfb janfb merged commit fb587e9 into main Jul 8, 2021
@janfb janfb deleted the prior-on-gpu branch July 8, 2021 08:37
@janfb janfb mentioned this pull request Jul 8, 2021
3 participants