Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parameter count in ModelSummary when parameters are DTensors #20163

Merged
merged 8 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from lightning_utilities.core.imports import package_available
from torch import Tensor
from torch.utils.data import Dataset, DistributedSampler, Sampler
from typing_extensions import Self, override
from typing_extensions import Self, TypeGuard, override

from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.data import _num_cpus_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import rank_zero_info
from lightning.fabric.utilities.types import _PATH, ReduceOp

Expand All @@ -30,6 +31,8 @@ class group: # type: ignore


if TYPE_CHECKING:
from torch.distributed._tensor import DTensor

from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.strategies import Strategy

Expand Down Expand Up @@ -427,3 +430,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.barrier()
if self.group is not None:
torch.distributed.destroy_process_group(self.group)


def _is_dtensor(tensor: Tensor) -> TypeGuard["DTensor"]:
if _TORCH_GREATER_EQUAL_2_4:
from torch.distributed._tensor import DTensor

return isinstance(tensor, DTensor)
return False
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))

- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))


## [2.3.0] - 2024-06-13
Expand Down
14 changes: 7 additions & 7 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch.utils.hooks import RemovableHandle

import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _is_dtensor
from lightning.pytorch.utilities.model_helpers import _ModuleMode
from lightning.pytorch.utilities.rank_zero import WarningCache

Expand Down Expand Up @@ -135,7 +136,7 @@ def layer_type(self) -> str:
@property
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

@property
def training(self) -> bool:
Expand Down Expand Up @@ -264,13 +265,11 @@ def total_training_modes(self) -> Dict[str, int]:

@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

@property
def trainable_parameters(self) -> int:
return sum(
p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters() if p.requires_grad
)
return sum(p.numel() if not _tensor_has_shape(p) else 0 for p in self._model.parameters() if p.requires_grad)

@property
def total_layer_params(self) -> int:
Expand Down Expand Up @@ -470,10 +469,11 @@ def get_human_readable_count(number: int) -> str:
return f"{number:,.1f} {labels[index]}"


def _is_lazy_weight_tensor(p: Tensor) -> bool:
def _tensor_has_shape(p: Tensor) -> bool:
from torch.nn.parameter import UninitializedParameter

if isinstance(p, UninitializedParameter):
# DTensor is a subtype of `UninitializedParameter`, but the shape is known
if isinstance(p, UninitializedParameter) and not _is_dtensor(p):
warning_cache.warn(
"The total number of parameters detected may be inaccurate because the model contains"
" an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array`"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
NOT_APPLICABLE,
LayerSummary,
ModelSummary,
_is_lazy_weight_tensor,
_tensor_has_shape,
get_human_readable_count,
)

Expand All @@ -40,7 +40,7 @@ class DeepSpeedLayerSummary(LayerSummary):
@override
def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())

@property
def average_shard_parameters(self) -> int:
Expand All @@ -49,7 +49,7 @@ def average_shard_parameters(self) -> int:
def partitioned_size(p: Parameter) -> int:
return p.partitioned_size() if RequirementCache("deepspeed<0.6.6") else p.partition_numel()

return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
return sum(partitioned_size(p) if not _tensor_has_shape(p) else 0 for p in self._module.parameters())


class DeepSpeedSummary(ModelSummary):
Expand All @@ -71,13 +71,13 @@ def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[overrid
@property
@override
def total_parameters(self) -> int:
return sum(deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
return sum(deepspeed_param_size(p) if not _tensor_has_shape(p) else 0 for p in self._model.parameters())

@property
@override
def trainable_parameters(self) -> int:
return sum(
deepspeed_param_size(p) if not _is_lazy_weight_tensor(p) else 0
deepspeed_param_size(p) if not _tensor_has_shape(p) else 0
for p in self._model.parameters()
if p.requires_grad
)
Expand Down
14 changes: 14 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from functools import partial
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import pytest
import torch
from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
Expand All @@ -15,6 +17,7 @@
_gather_all_tensors,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
_set_num_threads_if_needed,
_suggested_max_num_threads,
_sync_ddp,
Expand Down Expand Up @@ -234,3 +237,14 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.reset_mock()
_init_dist_connection(LightningEnvironment(), "gloo")
atexit_mock.register.assert_not_called()


@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
from torch.distributed._tensor import DTensor

assert _is_dtensor(Mock(spec=DTensor))
assert not _is_dtensor(torch.zeros(2, 2))

monkeypatch.setattr(lightning.fabric.utilities.distributed, "_TORCH_GREATER_EQUAL_2_4", False)
assert not _is_dtensor(Mock(spec=DTensor))
13 changes: 13 additions & 0 deletions tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from collections import OrderedDict
from typing import Any
from unittest import mock

import pytest
import torch
Expand Down Expand Up @@ -345,6 +346,18 @@ def test_lazy_model_summary():
assert summary.trainable_parameters == 0


@mock.patch("lightning.pytorch.utilities.model_summary.model_summary._is_dtensor", return_value=True)
def test_dtensor_model_summary(_):
"""Test that the model summary can work with layers that have DTensor parameters."""
# We mock the `_is_dtensor` to pretend parameters are DTensors, because testing with real DTensors
# would require setting up distributed
dtensor_model = UnorderedModel()
summary = ModelSummary(dtensor_model)
assert summary.total_layer_params > 0
assert summary.total_parameters > 0
assert summary.trainable_parameters > 0


@pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown."""
Expand Down
Loading