Automatically get device from first model parameter if device parameter to `summary` is `None` (#211)

* introduce `get_device` to torchinfo.py

- use it to determine the device if none is given to `summary`
- automatically determines the device of the model so that the model does not have to be moved in a multi-GPU setup (see the usage sketch after this list)
- should fix [issue #209](#209). A test checking this on a different device is still to be written; this commit sets up the functionality without breaking the old tests.

* add tests for multi-GPU to gpu_test.py
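A minimal usage sketch of the new behaviour (not part of the commit), assuming a machine with at least two GPUs and a torchinfo build that includes this change:

```python
import torch
from torchinfo import summary

# Hypothetical multi-GPU setup: the model already lives on cuda:1.
model = torch.nn.Linear(10, 10).to("cuda:1")

# No `device` argument: summary now infers the device from the first
# model parameter instead of defaulting to "cuda" (i.e. cuda:0) or "cpu".
summary(model, input_size=(2, 10))

# The model has not been moved off its original device.
assert next(model.parameters()).device == torch.device("cuda:1")
```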
snimu authored Jan 13, 2023
1 parent b4d80c0 commit 8305e1f
Showing 2 changed files with 37 additions and 1 deletion.
21 changes: 21 additions & 0 deletions tests/gpu_test.py
@@ -40,3 +40,24 @@ def test_input_size_half_precision() -> None:
        ),
    ):
        summary(test, dtypes=[torch.float16], input_size=(10, 2), device="cuda")


@pytest.mark.skipif(
    not torch.cuda.device_count() >= 2, reason="Only relevant for multi-GPU"
)
class TestMultiGPU:
    """multi-GPU-only tests"""

    @staticmethod
    def test_model_stays_on_device_if_gpu() -> None:
        model = torch.nn.Linear(10, 10).to("cuda:1")
        summary(model)
        model_parameter = next(model.parameters())
        assert model_parameter.device == torch.device("cuda:1")

    @staticmethod
    def test_different_model_parts_on_different_devices() -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 10).to(1), torch.nn.Linear(10, 10).to(0)
        )
        summary(model)
17 changes: 16 additions & 1 deletion torchinfo/torchinfo.py
@@ -206,7 +206,7 @@ class name as the key. If the forward pass is an expensive operation,
    cache_forward_pass = False

    if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        device = get_device(model)

    validate_user_params(
        input_data, input_size, columns, col_width, device, dtypes, verbose
@@ -446,6 +446,21 @@ def set_device(data: Any, device: torch.device | str) -> Any:
    )


def get_device(model: nn.Module) -> torch.device | str:
    """
    Gets device of first parameter of model and returns it if it is on cuda,
    otherwise returns cuda if available or cpu if not.
    """
    try:
        model_parameter = next(model.parameters())
    except StopIteration:
        model_parameter = None

    if model_parameter is not None and model_parameter.is_cuda:
        return model_parameter.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_input_data_sizes(data: Any) -> Any:
    """
    Converts input data to an equivalent data structure of torch.Sizes
