fix: remove emb net device handling, refactor get_numel #1186
Conversation
Force-pushed from 07f3fc0 to 06f83d5.
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1186      +/-   ##
==========================================
- Coverage   84.53%   75.56%    -8.97%
==========================================
  Files          94       95       +1
  Lines        7571     7576       +5
==========================================
- Hits         6400     5725     -675
- Misses       1171     1851     +680
Force-pushed from 06f83d5 to 3a82707.
Thanks a lot! A few questions and (optional) suggestions below.
Force-pushed from 479c0bb to 5afba70.
Looks great now, thanks a lot! I think there is one small bug, please have a look. Good to go afterwards though.
Hi, I've only been using sbi and torch for a short time. I used the following script to train SNPE; the input net and the training data are all on device "cuda". Based on what you updated, does it mean that SNPE should only be trained on the CPU?
Here are the errors I got from it:
I tried with the cpu device as well, but got:
I'm not sure if I understood it correctly, could you maybe help? I really appreciate it.
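(For reference, since the script itself is not included above: a minimal sketch of the kind of GPU setup described, with a purely illustrative prior, data, and embedding net; import paths and argument names may differ across sbi versions.)

```python
import torch
from torch import nn
from sbi.inference import SNPE
from sbi.utils import BoxUniform, posterior_nn

# Toy prior, "simulated" data, and embedding net, all placed on the GPU.
prior = BoxUniform(low=-torch.ones(3), high=torch.ones(3))
embedding_net = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 8)).to("cuda")

theta = prior.sample((1000,)).to("cuda")
x = torch.randn(1000, 10, device="cuda")  # stand-in for simulator output

# Build the density estimator with the embedding net and request training on "cuda".
density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)
inference = SNPE(prior=prior, density_estimator=density_estimator, device="cuda")
inference.append_simulations(theta, x)
trained_estimator = inference.train()
```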
Context
The starting point for this PR is #1161, the incorrect warning that the embedding net device and the data device do not match.
On the way I realized that we are treating the `embedding_net` as a separate net that can have its own device, different from the actual net. I think this does not make sense.

In general, the device handling should be centralized, e.g., have a single entry point. At the moment, this entry point is the inference object, e.g., `SNPE(..., device=device)`. But there are different scenarios:

- `SNPE`: all good, device handling is centralized via the `device` passed to `SNPE`.
- `posterior_nn` is used to build a flow with an embedding net: `posterior_nn` normally returns a net on the cpu, but if the `embedding_net` passed by the user is on a different device, things might crash.
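To make the second scenario concrete, a hypothetical illustration (import paths and arguments may differ across sbi versions):

```python
import torch
from torch import nn
from sbi.utils import posterior_nn

# User-supplied embedding net that already lives on the GPU.
embedding_net = nn.Linear(100, 10).to("cuda")

# posterior_nn returns a builder; the flow it builds lives on the cpu.
build_density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)

theta = torch.randn(50, 3)  # cpu tensors, as used when building the flow
x = torch.randn(50, 100)
# With a cuda embedding net and cpu data, this build (or later training)
# can fail with a device-mismatch error.
density_estimator = build_density_estimator(theta, x)
```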
My suggestions

- EDIT: Does not make sense, because in the standard case the net will be on the cpu and be moved to the training device later, so there would be a mismatch. So we can either move the passed net to the cpu entirely (bad) or, in those few cases where users pass large nets that have to be on the GPU, accept potential device mismatches.
- We check in `posterior_nn` etc. that the passed embedding net is on the cpu, or we move it there. EDIT: I will add a function that checks the embedding net device and, if it is not on the cpu, warns and moves it there.
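A rough sketch of what such a check-and-move helper could look like (hypothetical name and signature; the function actually added in this PR may differ):

```python
import warnings

from torch import nn


def check_embedding_net_device(embedding_net: nn.Module) -> nn.Module:
    """Warn and move the embedding net to the cpu if it lives elsewhere.

    Hypothetical sketch of the helper described above.
    """
    params = list(embedding_net.parameters())
    if not params:
        # Nets without parameters (e.g., nn.Identity) have no device to check.
        return embedding_net
    device = params[0].device
    if device.type != "cpu":
        warnings.warn(
            f"The embedding net is on device '{device}', but the density "
            "estimator is built on the cpu and moved to the training device "
            "later. Moving the embedding net to the cpu."
        )
        embedding_net = embedding_net.to("cpu")
    return embedding_net
```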
What this PR does so far

- remove `embedding_net` device checking in `build_posterior` (and add a test)
- refactor `get_numel` to be used across the neural net factory (rough sketch below). I had to put it into a separate utils file because putting it into `sbiutils` or `torchutils` causes circular imports 😵

fixes #1161
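For context, `get_numel` presumably infers the number of elements of a single (optionally embedded) datapoint so the net factory knows which input/context dimensionality to build the flow with. A minimal sketch under that assumption; the helper's name comes from the PR, but its signature and behavior here are guessed:

```python
from typing import Optional

import torch
from torch import nn


def get_numel(batch: torch.Tensor, embedding_net: Optional[nn.Module] = None) -> int:
    """Number of elements of one (optionally embedded) datapoint.

    Hypothetical sketch; the actual get_numel in sbi may differ.
    """
    if embedding_net is None:
        embedding_net = nn.Identity()
    with torch.no_grad():
        # Pass a single example through the embedding net and count its elements.
        return embedding_net(batch[:1]).numel()


# Example: size the flow's context from an image batch and an embedding net.
x_batch = torch.randn(100, 28, 28)
embedding = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32))
x_numel = get_numel(x_batch, embedding_net=embedding)  # -> 32
```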