diff --git a/tests/gpu_test.py b/tests/gpu_test.py
index 0e9aa93..1c67915 100644
--- a/tests/gpu_test.py
+++ b/tests/gpu_test.py
@@ -40,3 +40,24 @@ def test_input_size_half_precision() -> None:
         ),
     ):
         summary(test, dtypes=[torch.float16], input_size=(10, 2), device="cuda")
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2, reason="Only relevant for multi-GPU"
+)
+class TestMultiGPU:
+    """multi-GPU-only tests"""
+
+    @staticmethod
+    def test_model_stays_on_device_if_gpu() -> None:
+        model = torch.nn.Linear(10, 10).to("cuda:1")
+        summary(model)
+        model_parameter = next(model.parameters())
+        assert model_parameter.device == torch.device("cuda:1")
+
+    @staticmethod
+    def test_different_model_parts_on_different_devices() -> None:
+        model = torch.nn.Sequential(
+            torch.nn.Linear(10, 10).to(1), torch.nn.Linear(10, 10).to(0)
+        )
+        summary(model)
diff --git a/torchinfo/torchinfo.py b/torchinfo/torchinfo.py
index a905c0e..73374a1 100644
--- a/torchinfo/torchinfo.py
+++ b/torchinfo/torchinfo.py
@@ -206,7 +206,7 @@ class name as the key. If the forward pass is an expensive operation,
     cache_forward_pass = False
 
     if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        device = get_device(model)
 
     validate_user_params(
         input_data, input_size, columns, col_width, device, dtypes, verbose
@@ -446,6 +446,21 @@ def set_device(data: Any, device: torch.device | str) -> Any:
     )
 
 
+def get_device(model: nn.Module) -> torch.device | str:
+    """
+    Gets the device of the model's first parameter and returns it if it is
+    a CUDA device; otherwise returns cuda if available, or cpu if not.
+    """
+    try:
+        model_parameter = next(model.parameters())
+    except StopIteration:
+        model_parameter = None
+
+    if model_parameter is not None and model_parameter.is_cuda:
+        return model_parameter.device
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
 def get_input_data_sizes(data: Any) -> Any:
     """
     Converts input data to an equivalent data structure of torch.Sizes
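
A minimal usage sketch of the behavior this patch changes, assuming a host
with at least two CUDA devices (`summary` is torchinfo's public entry point):

    import torch
    from torchinfo import summary

    # Previously, summary() defaulted to "cuda" (i.e. cuda:0) whenever CUDA
    # was available, moving the model off whichever device the caller chose.
    model = torch.nn.Linear(10, 10).to("cuda:1")

    # With get_device(), the first parameter's device wins when it is a CUDA
    # device, so the model stays on cuda:1 after summarization.
    summary(model)
    assert next(model.parameters()).device == torch.device("cuda:1")

    # CPU models (and parameterless models) keep the old fallback:
    # "cuda" if available, otherwise "cpu".
    summary(torch.nn.Linear(10, 10), input_size=(1, 10))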