From 0fc4dde4a9a271bee959f19c1fb0623369e546da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 4 Aug 2024 18:24:25 -0400 Subject: [PATCH 1/8] update --- .../pytorch/utilities/model_summary/model_summary.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 0f48bee191c7b..070f6d6ea0d28 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -27,6 +27,7 @@ import lightning.pytorch as pl from lightning.pytorch.utilities.model_helpers import _ModuleMode from lightning.pytorch.utilities.rank_zero import WarningCache +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 log = logging.getLogger(__name__) warning_cache = WarningCache() @@ -473,6 +474,12 @@ def get_human_readable_count(number: int) -> str: def _is_lazy_weight_tensor(p: Tensor) -> bool: from torch.nn.parameter import UninitializedParameter + if _TORCH_GREATER_EQUAL_2_4: + from torch.distributed._tensor import DTensor + + if isinstance(p, DTensor): + return False + if isinstance(p, UninitializedParameter): warning_cache.warn( "The total number of parameters detected may be inaccurate because the model contains" From 0a69fcff89cc353a8a80dad92ccb876021b8e825 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 00:31:47 +0200 Subject: [PATCH 2/8] update --- src/lightning/fabric/utilities/distributed.py | 12 ++++++++++- .../utilities/model_summary/model_summary.py | 20 ++++++------------- .../model_summary/model_summary_deepspeed.py | 10 +++++----- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 75b2f7c580b6f..e39bcf37015f6 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -7,7 +7,7 @@ from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union, TypeGuard import torch import torch.nn.functional as F @@ -20,6 +20,7 @@ from lightning.fabric.utilities.data import _num_cpus_available from lightning.fabric.utilities.rank_zero import rank_zero_info from lightning.fabric.utilities.types import _PATH, ReduceOp +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 if torch.distributed.is_available(): from torch.distributed import group @@ -32,6 +33,7 @@ class group: # type: ignore if TYPE_CHECKING: from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.strategies import Strategy + from torch.distributed._tensor import DTensor log = logging.getLogger(__name__) @@ -427,3 +429,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.barrier() if self.group is not None: torch.distributed.destroy_process_group(self.group) + + +def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]: + if _TORCH_GREATER_EQUAL_2_4: + from torch.distributed._tensor import DTensor + + return isinstance(tensor, DTensor) + return False diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 070f6d6ea0d28..fe83042eb0eb6 100644 --- 
a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -25,9 +25,9 @@ from torch.utils.hooks import RemovableHandle import lightning.pytorch as pl +from lightning.fabric.utilities.distributed import _is_dtensor from lightning.pytorch.utilities.model_helpers import _ModuleMode from lightning.pytorch.utilities.rank_zero import WarningCache -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 log = logging.getLogger(__name__) warning_cache = WarningCache() @@ -136,7 +136,7 @@ def layer_type(self) -> str: @property def num_parameters(self) -> int: """Returns the number of parameters in this module.""" - return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) @property def training(self) -> bool: @@ -265,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]: @property def total_parameters(self) -> int: - return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters()) @property def trainable_parameters(self) -> int: - return sum( - p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad - ) + return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad) @property def total_layer_params(self) -> int: @@ -471,16 +469,10 @@ def get_human_readable_count(number: int) -> str: return f"{number:,.1f} {labels[index]}" -def _is_lazy_weight_tensor(p: Tensor) -> bool: +def _tensor_has_shape(p: Tensor) -> bool: from torch.nn.parameter import UninitializedParameter - if _TORCH_GREATER_EQUAL_2_4: - from torch.distributed._tensor import DTensor - - if isinstance(p, DTensor): - return False - - if isinstance(p, UninitializedParameter): + if isinstance(p, UninitializedParameter) and not _is_dtensor(p): warning_cache.warn( "The total number of parameters detected may be inaccurate because the model contains" " an instance of `UninitializedParameter`. 
To get an accurate number, set `self.example_input_array`" diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index c3c9cfe9823a3..57d9ae5024b58 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -25,7 +25,7 @@ NOT_APPLICABLE, LayerSummary, ModelSummary, - _is_lazy_weight_tensor, + _tensor_has_shape, get_human_readable_count, ) @@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary): @override def num_parameters(self) -> int: """Returns the number of parameters in this module.""" - return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) @property def average_shard_parameters(self) -> int: @@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int: def partitioned_size(p: Parameter) -> int: return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel() - return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters()) + return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters()) class DeepSpeedSummary(ModelSummary): @@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[overrid @property @override def total_parameters(self) -> int: - return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters()) + return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters()) @property @override def trainable_parameters(self) -> int: return sum( - deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 + deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad ) From cf1ad7d160d45cd146e639101037de0023d2347b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 01:14:37 +0200 Subject: [PATCH 3/8] add simple test --- .../utilities/test_model_summary.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index 00fdf77d4cdfd..c694da06c0dfb 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -13,6 +13,8 @@ # limitations under the License. 
from collections import OrderedDict from typing import Any +from unittest import mock +from unittest.mock import Mock, MagicMock import pytest import torch @@ -121,6 +123,18 @@ def forward(self, inp): return self.layer2(self.layer1(inp)) +# class FakeDTensorModel(LightningModule): +# """A model which contains DTensor parameters.""" +# +# def __init__(self): +# super().__init__() +# from torch.distributed._tensor import DTensor +# # self.parameter = nn.Parameter(DTensor.from_local(torch.rand(2, 2))) +# +# self.parameter = MagicMock(wraps=nn.Parameter, spec=DTensor)(torch.rand(2, 2)) +# assert isinstance(self.parameter, DTensor) + + class DeepNestedModel(LightningModule): """A model with deep nested layers.""" @@ -345,6 +359,18 @@ def test_lazy_model_summary(): assert summary.trainable_parameters == 0 +@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True) +def test_dtensor_model_summary(_): + """Test that the model summary can work with layers that have DTensor parameters.""" + # We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors + # would require setting up distributed + dtensor_model = UnorderedModel() + summary = ModelSummary(dtensor_model) + assert summary.total_layer_params > 0 + assert summary.total_parameters > 0 + assert summary.trainable_parameters > 0 + + @pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999]) def test_max_depth_param(max_depth): """Test that only the modules up to the desired depth are shown.""" From a3a408bb5239089d6bff6a23afcdfb355fefe243 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 01:20:14 +0200 Subject: [PATCH 4/8] add test --- tests/tests_fabric/utilities/test_distributed.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 2c30b3aa62ddf..486259b451f7a 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -3,9 +3,11 @@ from functools import partial from pathlib import Path from unittest import mock +from unittest.mock import Mock import pytest import torch +import lightning.fabric from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy @@ -19,6 +21,7 @@ _suggested_max_num_threads, _sync_ddp, is_shared_filesystem, + _is_dtensor, ) from lightning_utilities.core.imports import RequirementCache @@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock): atexit_mock.reset_mock() _init_dist_connection(LightningEnvironment(), "gloo") atexit_mock.register.assert_not_called() + + +@RunIf(min_torch="2.4") +def test_is_dtensor(monkeypatch): + from torch.distributed._tensor import DTensor + + assert _is_dtensor(Mock(spec=DTensor)) + assert not _is_dtensor(torch.zeros(2, 2)) + + monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False) + assert not _is_dtensor(Mock(spec=DTensor)) From 9d6dfc8ac51a77ecab7e544d417622aa0aea10fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Aug 2024 23:23:37 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
src/lightning/fabric/utilities/distributed.py | 7 ++++--- tests/tests_fabric/utilities/test_distributed.py | 4 ++-- tests/tests_pytorch/utilities/test_model_summary.py | 1 - 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index e39bcf37015f6..87c00b022d105 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -7,7 +7,7 @@ from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union, TypeGuard +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, TypeGuard, Union import torch import torch.nn.functional as F @@ -18,9 +18,9 @@ from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.data import _num_cpus_available +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.rank_zero import rank_zero_info from lightning.fabric.utilities.types import _PATH, ReduceOp -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 if torch.distributed.is_available(): from torch.distributed import group @@ -31,9 +31,10 @@ class group: # type: ignore if TYPE_CHECKING: + from torch.distributed._tensor import DTensor + from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.strategies import Strategy - from torch.distributed._tensor import DTensor log = logging.getLogger(__name__) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index 486259b451f7a..cc6c23bddbd7b 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -5,9 +5,9 @@ from unittest import mock from unittest.mock import Mock +import lightning.fabric import pytest import torch -import lightning.fabric from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy @@ -17,11 +17,11 @@ _gather_all_tensors, _InfiniteBarrier, _init_dist_connection, + _is_dtensor, _set_num_threads_if_needed, _suggested_max_num_threads, _sync_ddp, is_shared_filesystem, - _is_dtensor, ) from lightning_utilities.core.imports import RequirementCache diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index c694da06c0dfb..d0eed8aadb041 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -14,7 +14,6 @@ from collections import OrderedDict from typing import Any from unittest import mock -from unittest.mock import Mock, MagicMock import pytest import torch From 9e70ae24ca14e1d072acc89d92c69ed11c3141fa Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 01:25:07 +0200 Subject: [PATCH 6/8] update --- src/lightning/pytorch/CHANGELOG.md | 1 + tests/tests_pytorch/utilities/test_model_summary.py | 12 ------------ 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 4d8eebf134b75..eba67ebffcbbe 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -49,6 +49,7 @@ The format is 
based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814)) +- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163)) ## [2.3.0] - 2024-06-13 diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py index d0eed8aadb041..cced6546aab75 100644 --- a/tests/tests_pytorch/utilities/test_model_summary.py +++ b/tests/tests_pytorch/utilities/test_model_summary.py @@ -122,18 +122,6 @@ def forward(self, inp): return self.layer2(self.layer1(inp)) -# class FakeDTensorModel(LightningModule): -# """A model which contains DTensor parameters.""" -# -# def __init__(self): -# super().__init__() -# from torch.distributed._tensor import DTensor -# # self.parameter = nn.Parameter(DTensor.from_local(torch.rand(2, 2))) -# -# self.parameter = MagicMock(wraps=nn.Parameter, spec=DTensor)(torch.rand(2, 2)) -# assert isinstance(self.parameter, DTensor) - - class DeepNestedModel(LightningModule): """A model with deep nested layers.""" From 2aed8dbb8db0b370f1ce680c9919194d32f03a6b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 01:31:59 +0200 Subject: [PATCH 7/8] update --- src/lightning/fabric/utilities/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 87c00b022d105..0e6c52dfb09b9 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -7,14 +7,14 @@ from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, TypeGuard, Union +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union import torch import torch.nn.functional as F from lightning_utilities.core.imports import package_available from torch import Tensor from torch.utils.data import Dataset, DistributedSampler, Sampler -from typing_extensions import Self, override +from typing_extensions import Self, TypeGuard, override from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.data import _num_cpus_available From d4cdde4db3d3f59caadf00728b7a88b349494a65 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 5 Aug 2024 01:34:45 +0200 Subject: [PATCH 8/8] add comment --- src/lightning/pytorch/utilities/model_summary/model_summary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index fe83042eb0eb6..c40dc94568a51 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -472,6 +472,7 @@ def get_human_readable_count(number: int) -> str: def _tensor_has_shape(p: Tensor) -> bool: from torch.nn.parameter import UninitializedParameter + # DTensor is a subtype of `UninitializedParameter`, but the shape is known if isinstance(p, UninitializedParameter) and not _is_dtensor(p): warning_cache.warn( "The total number of parameters detected may be inaccurate because the model contains"
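Usage note (editor's sketch, not part of the patch series above): a minimal illustration of the behavior the new `_is_dtensor` type guard is expected to have, mirroring the unit test added in patch 4. It assumes torch>=2.4 so that `torch.distributed._tensor.DTensor` is importable; on older torch versions the guard simply returns False.

    # Sketch only: exercises the new guard the same way the added test does.
    from unittest.mock import Mock

    import torch
    from torch.distributed._tensor import DTensor  # importable with torch>=2.4

    from lightning.fabric.utilities.distributed import _is_dtensor  # helper added by this series

    assert _is_dtensor(Mock(spec=DTensor))     # a DTensor (here mocked via spec) is recognized
    assert not _is_dtensor(torch.zeros(2, 2))  # a plain tensor falls through to False

Because `_tensor_has_shape` now excludes DTensors from the `UninitializedParameter` branch, `ModelSummary` counts distributed parameters with `p.numel()` like any other materialized parameter.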