Skip to content

Commit

Permalink
Prepare for ShardedTensor deprecation
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 28, 2023
1 parent 7bc39ae commit cab7d42
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
# Minimum-torch-version flags used for feature gating elsewhere in the codebase.
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
_TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
# NOTE(review): use_base_version=True presumably compares against the base release
# so that pre-release/nightly builds (e.g. "2.1.0.dev...") also satisfy the check
# for not-yet-released torch versions — confirm against compare_version's contract.
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
5 changes: 4 additions & 1 deletion src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
Expand Down Expand Up @@ -1436,6 +1436,9 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
"""
if _TORCH_GREATER_EQUAL_2_1:
# ShardedTensor is deprecated in favor of DistributedTensor
return
if _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def assert_device(device: torch.device) -> None:
assert_device(torch.device("cpu"))


@RunIf(skip_windows=True)
@RunIf(skip_windows=True, max_torch="2.1.0")
def test_sharded_tensor_state_dict(single_process_pg):
from torch.distributed._shard.sharded_tensor import empty as sharded_tensor_empty
from torch.distributed._sharding_spec import ChunkShardingSpec
Expand Down

0 comments on commit cab7d42

Please sign in to comment.