From cab7d422700845f953f5a8e5d7ee006ccefe5b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 28 Feb 2023 02:55:07 +0100 Subject: [PATCH 1/2] Prepare for ShardedTensor deprecation --- src/lightning/fabric/utilities/imports.py | 1 + src/lightning/pytorch/core/module.py | 5 ++++- tests/tests_pytorch/core/test_lightning_module.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a58a940152bbf..b5b1441781640 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -28,3 +28,4 @@ _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0") _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0") _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) +_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index a68314ad5ca7f..495efaadd11ab 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -34,7 +34,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp -from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks @@ -1436,6 +1436,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None: These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. """ + if _TORCH_GREATER_EQUAL_2_1: + # ShardedTensor is deprecated in favor of DistributedTensor + return if _IS_WINDOWS or not torch.distributed.is_available(): rank_zero_debug("Could not register sharded tensor state dict hooks") return diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index c7ad77f0dcbfb..2b4b2c5cb201e 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -310,7 +310,7 @@ def assert_device(device: torch.device) -> None: assert_device(torch.device("cpu")) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, max_torch="2.1.0") def test_sharded_tensor_state_dict(single_process_pg): from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty from torch.distributed._sharding_spec import ChunkShardingSpec From a6f956b1a7a167c697b85c859aa996438ef720f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 28 Feb 2023 02:58:40 +0100 Subject: [PATCH 2/2] CHANGELOG --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 6c2ff7f0509b9..c8dc5b6ff8be2 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -386,6 +386,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the `lightning.pytorch.strategies.DDPSpawnStrategy` in favor of `DDPStrategy(start_method='spawn')` (merged both classes) ([#16809](https://github.com/Lightning-AI/lightning/pull/16809)) +- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892)) + + ### Fixed