diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index deaf3a2f016d2..c235045580dd7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892)) + + + ### Fixed - Fixed `num_nodes` not being set for `DDPFullyShardedNativeStrategy` ([#17160](https://github.com/Lightning-AI/lightning/pull/17160)) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 88f3c9cca81ac..3cb578e095655 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -36,7 +36,12 @@ from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning_fabric.utilities.distributed import _distributed_available, _sync_ddp -from lightning_fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_2_0 +from lightning_fabric.utilities.imports import ( + _IS_WINDOWS, + _TORCH_GREATER_EQUAL_1_11, + _TORCH_GREATER_EQUAL_2_0, + _TORCH_GREATER_EQUAL_2_1, +) from lightning_fabric.utilities.types import Steppable from lightning_fabric.wrappers import _FabricOptimizer from pytorch_lightning.callbacks.callback import Callback @@ -2018,6 +2023,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None: These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. """ + if _TORCH_GREATER_EQUAL_2_1: + # ShardedTensor is deprecated in favor of DistributedTensor + return if _IS_WINDOWS or not torch.distributed.is_available(): rank_zero_debug("Could not register sharded tensor state dict hooks") return diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index cc7c48a4b9806..8490ad26f769b 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -311,7 +311,7 @@ def assert_device(device: torch.device) -> None: assert_device(torch.device("cpu")) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, max_torch="2.1.0") def test_sharded_tensor_state_dict(single_process_pg): if _TORCH_GREATER_EQUAL_1_11: from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty