Describe the bug
For SNPE_C, and likely all SNPE posteriors, the `train` function unconditionally moves all data to the "cpu" device to build the flow (by calling `self._build_neural_net`). This triggers a UserWarning from `sbi/utils/user_input_checks.py:444`:
sbi/utils/user_input_checks.py:444: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'. Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.
After reading through the code and stepping through the `train` function in a debugger, I found that the device switch announced in this warning is not honored during training.
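To make the failure mode concrete, here is a minimal pure-Python sketch of the kind of device check that emits this warning. It does not use sbi or torch: `check_embedding_net_device`, `build_net` and `run` are hypothetical stand-ins, and devices are plain strings. The only assumption taken from the warning text is that, on a mismatch, the embedding net follows the device of the data.

```python
import warnings


def check_embedding_net_device(data_device: str, net_device: str) -> str:
    # Hypothetical stand-in for the check behind the quoted warning
    # (not sbi's actual code): on a mismatch, warn and follow the data's device.
    if data_device != net_device:
        warnings.warn(
            "Mismatch between the device of the data fed to the embedding_net "
            "and the device of the embedding_net's weights. Fed data has device "
            f"'{data_device}' vs embedding_net weights have device '{net_device}'.",
            UserWarning,
        )
        return data_device
    return net_device


def build_net(training_device: str, net_device: str, hardcode_cpu: bool) -> str:
    # Hypothetical build step: the data used to build the flow is either moved
    # to "cpu" unconditionally (the reported behavior) or kept on the device
    # used for training.
    build_device = "cpu" if hardcode_cpu else training_device
    return check_embedding_net_device(build_device, net_device)


def run(hardcode_cpu: bool):
    # Record warnings so they can be inspected instead of printed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        device = build_net("cuda:0", "cuda:0", hardcode_cpu)
    return device, [str(w.message) for w in caught]


if __name__ == "__main__":
    print(run(hardcode_cpu=True))   # embedding net pushed to "cpu", one warning
    print(run(hardcode_cpu=False))  # stays on "cuda:0", no warning
```

With the hardcoded move, an embedding net that already sits on `cuda:0` is pushed to `cpu` and the warning fires; building on the training device leaves the net where it is and emits nothing.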
To Reproduce
Please add a minimal code example that reproduces the problem:
Expected behavior
I'd expect the UserWarning not to trigger.
Additional context
It is unclear to me why the transfer to the cpu device was hardcoded there. The UserWarning originates from `sbi/neural_nets/flow.py:341` in 0.22.0. Perhaps this problem is resolved in main due to the switch to zuko?