Prepare for ShardedTensor deprecation #16892

Merged · 5 commits · Mar 6, 2023
Changes from all commits
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -391,6 +391,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the `lightning.pytorch.strategies.DDPSpawnStrategy` in favor of `DDPStrategy(start_method='spawn')` (merged both classes) ([#16809](https://github.com/Lightning-AI/lightning/pull/16809))


- Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))



### Fixed

- Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
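For context, the `_TORCH_GREATER_EQUAL_2_1` flag imported in the change below is a version guard. A minimal sketch of how such a flag is typically defined, assuming the `compare_version` helper from `lightning_utilities` (the exact definition in `lightning.fabric.utilities.imports` may differ):

```python
import operator

from lightning_utilities.core.imports import compare_version

# Sketch only: True when the installed torch is at least 2.1.0.
# use_base_version makes pre-releases such as 2.1.0.dev count as 2.1.0.
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
```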
5 changes: 4 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -35,7 +35,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
@@ -1440,6 +1440,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:

These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
"""
if _TORCH_GREATER_EQUAL_2_1:
# ShardedTensor is deprecated in favor of DistributedTensor
return
if _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return
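With this change, the registration helper becomes a no-op on `torch>=2.1` and keeps its previous behavior otherwise. A minimal sketch of the resulting control flow is shown below; the hook registration at the end is an assumption about the unchanged remainder of the function, based on the `state_dict_hook` and `pre_load_state_dict_hook` utilities that PyTorch ships in `torch.distributed._shard.sharded_tensor`:

```python
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
    """Adds ShardedTensor state dict hooks if ShardedTensors are supported.

    These hooks ensure that ShardedTensors are included when saving, and are loaded into the
    LightningModule correctly.
    """
    if _TORCH_GREATER_EQUAL_2_1:
        # ShardedTensor is deprecated in favor of DistributedTensor
        return
    if _IS_WINDOWS or not torch.distributed.is_available():
        rank_zero_debug("Could not register sharded tensor state dict hooks")
        return

    # Assumed continuation (not part of this diff): wire up PyTorch's ShardedTensor hooks
    # so sharded parameters round-trip through state_dict() / load_state_dict().
    from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    self._register_state_dict_hook(state_dict_hook)
    self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
```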
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_module.py
@@ -310,7 +310,7 @@ def assert_device(device: torch.device) -> None:
assert_device(torch.device("cpu"))


@RunIf(skip_windows=True)
@RunIf(skip_windows=True, max_torch="2.1.0")
def test_sharded_tensor_state_dict(single_process_pg):
from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec
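The new `max_torch="2.1.0"` condition skips this test on PyTorch 2.1 and newer, where the hooks are no longer registered. As a rough sketch of the equivalent plain-pytest gating (assuming `compare_version` from `lightning_utilities`; `RunIf` is Lightning's own marker helper and may implement this differently):

```python
import operator
import sys

import pytest
from lightning_utilities.core.imports import compare_version

# Hypothetical stand-in for @RunIf(skip_windows=True, max_torch="2.1.0"):
# skip on Windows and on torch >= 2.1.0, where the ShardedTensor hooks are not registered.
requires_sharded_tensor_hooks = pytest.mark.skipif(
    sys.platform == "win32" or compare_version("torch", operator.ge, "2.1.0"),
    reason="requires torch<2.1 and a non-Windows platform for ShardedTensor state dict hooks",
)
```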