Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove emb net device handling, refactor get_numel #1186

Merged
merged 6 commits into from
Jul 2, 2024

Conversation

janfb
Copy link
Contributor

@janfb janfb commented Jun 25, 2024

Context

The starting point for this PR is #1161, the incorrect warning that embedding net and data device do not match.
On the way I realized that we are treating the embedding_net as separate net that can have its own device, different from the actual net. I think this does not make sense.

In general, the device handling should be centralized, e.g., have a single entry point. At the moment, this entry point is the inference object, e.g., SNPE(..., device=device). But are the different scenarios:

  1. user passes a model str to SNPE: all good, device handling is centralized via the device
  2. user passes a custom net, which could be on a device already. Then this device must match the device passed to SNPE.
  3. user uses e.g., posterior_nn to build a flow with an embedding net. posterior_nn normally returns a net on the cpu. but if the embedding_net passed by the user is on a different device, things might crash.

My suggestions

  1. we assert that the device of a passed net matches the device passed to the inference object.
    EDIT: Does not make sense because in the standard case it will be in cpu and be moved to training device later, so there will be a mismatch. So we can either move the passed net to cpu entirely (bad), or, in those few cases where users pass large nets that have to be on the GPU, accept potential device mismatches.
  2. we assert in posterior_nn etc, that the passed embedding net is on the cpu, or we move it there.
    EDIT: I will add a function that checks the embedding net device and if it is not on cpu, it warns and moves it there.

What this PR does so far

  • remove all the separate embedding_net device checking
  • fix a small bug when inferring the device in build_posterior (and add test)
  • unify get_numel to be used across the neural net factory. I had to put it into a separate utils file because putting it into sbiutils or torchutils causes circular imports 😵
  • add 2. from above
  • add 3. from above.

fixes #1161

@janfb janfb force-pushed the fix-device-warnings branch from 07f3fc0 to 06f83d5 Compare June 26, 2024 16:34
Copy link

codecov bot commented Jun 26, 2024

Codecov Report

Attention: Patch coverage is 92.45283% with 4 lines in your changes missing coverage. Please review.

Project coverage is 75.56%. Comparing base (337f072) to head (c0f5a6d).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1186      +/-   ##
==========================================
- Coverage   84.53%   75.56%   -8.97%     
==========================================
  Files          94       95       +1     
  Lines        7571     7576       +5     
==========================================
- Hits         6400     5725     -675     
- Misses       1171     1851     +680     
Flag Coverage Δ
unittests 75.56% <92.45%> (-8.97%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
sbi/inference/snpe/snpe_base.py 90.96% <100.00%> (ø)
sbi/neural_nets/classifier.py 100.00% <100.00%> (ø)
sbi/neural_nets/factory.py 90.00% <100.00%> (+0.52%) ⬆️
sbi/neural_nets/flow.py 93.12% <100.00%> (+0.07%) ⬆️
sbi/neural_nets/mdn.py 100.00% <100.00%> (ø)
sbi/utils/__init__.py 100.00% <ø> (ø)
sbi/utils/user_input_checks.py 83.51% <ø> (+0.26%) ⬆️
sbi/inference/snle/snle_base.py 93.81% <0.00%> (ø)
sbi/inference/snre/snre_base.py 94.44% <0.00%> (ø)
sbi/utils/nn_utils.py 88.23% <88.23%> (ø)

... and 25 files with indirect coverage changes

@janfb janfb force-pushed the fix-device-warnings branch from 06f83d5 to 3a82707 Compare July 2, 2024 09:38
@janfb janfb requested a review from michaeldeistler July 2, 2024 09:40
Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! A few questions and (optional) suggestions below.

sbi/utils/nn_utils.py Show resolved Hide resolved
sbi/utils/nn_utils.py Show resolved Hide resolved
tests/inference_on_device_test.py Show resolved Hide resolved
@janfb janfb requested a review from michaeldeistler July 2, 2024 11:59
@janfb janfb force-pushed the fix-device-warnings branch from 479c0bb to 5afba70 Compare July 2, 2024 12:10
Copy link
Contributor

@michaeldeistler michaeldeistler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great now, thanks a lot! I think there is one small bug, please have a look. Good to go afterwards though.

sbi/neural_nets/flow.py Outdated Show resolved Hide resolved
@janfb janfb merged commit 6fd2a6b into main Jul 2, 2024
6 checks passed
@janfb janfb deleted the fix-device-warnings branch July 2, 2024 15:01
@ningyuxin1999
Copy link

Hi, I've been trying to use sbi and torch for not very long time. I used the following script to train SNPE. The input net and training data are all of device "cuda". Based on what you updated, does it mean that the SNPE should only trained on cpu?

...
neural_posterior = posterior_nn(model="maf", embedding_net=net, hidden_features=10, num_transforms=2)
inference = SNPE(prior=prior, device="cuda", density_estimator=neural_posterior)

for e in range(num_epoch):
    errors_train = []
    for i, (scen_idx, inputs_train, targets_train) in enumerate(train_loader):
        inputs_train, targets_train = inputs_train.to(device), targets_train.to(device)
        embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
        posterior = DirectPosterior(posterior_estimator=embedded_sbi, prior=prior, device="cuda")
        ...

Here are the errors I got from it:

/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/user_input_checks.py:444: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'. Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.
  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 346, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 292, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 141, in training_loop
    embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 180, in train
    return super().train(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 317, in train
    self._neural_net = self._build_neural_net(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/get_nn_models.py", line 265, in build_fn
    return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/neural_nets/flow.py", line 139, in build_maf
    y_numel = embedding_net(batch_y[:1]).numel()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

I tried with cpu device as well, but got:

  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 349, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 295, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 131, in training_loop
    inference = SNPE(prior=prior,device=device, density_estimator=neural_posterior)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 84, in __init__
    super().__init__(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 64, in __init__
    super().__init__(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/base.py", line 111, in __init__
    self._device = process_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/torchutils.py", line 48, in process_device
    torch.cuda.set_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/__init__.py", line 397, in set_device
    device = _get_device_index(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/_utils.py", line 34, in _get_device_index
    raise ValueError(f"Expected a cuda device, but got: {device}")
ValueError: Expected a cuda device, but got: cpu

I'm not sure if I understood it correctly, could you maybe help? I really appriciate that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

hardcoded transfer of data and model to cpu triggers UserWarning
3 participants