diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index c428091b58df9..06e60db1e283a 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):
 
 def parameter_validation(fn: Callable) -> Callable:
     """
-    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
     Validates that the module parameter lengths match after moving to the device. It is useful
     when tying weights on TPU's.
 
     Args:
-        fn: ``.to`` method
+        fn: ``model_to_device`` method
 
     Note:
         TPU's require weights to be tied/shared after moving the module to the device.
@@ -90,10 +89,10 @@ def parameter_validation(fn: Callable) -> Callable:
 
     @wraps(fn)
     def inner_fn(self, *args, **kwargs):
-        pre_layer_count = len(list(self.parameters()))
+        pre_layer_count = len(list(self.model.parameters()))
         module = fn(self, *args, **kwargs)
-        self.on_post_move_to_device()
-        post_layer_count = len(list(self.parameters()))
+        self.model.on_post_move_to_device()
+        post_layer_count = len(list(self.model.parameters()))
 
         if not pre_layer_count == post_layer_count:
             rank_zero_warn(
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index a61dd1cbc5dbd..99abff992ebeb 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -15,6 +15,7 @@
 
 import torch
 
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -39,6 +40,7 @@ def __init__(self, device: int, debug: bool = False):
     def is_distributed(self) -> bool:
         return False
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.model.to(self.root_device)
 
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 8a93faa0281cd..9ac1e757b2b6d 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,6 +23,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)
 
+    @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index a6a26b142bc16..13f16d9b426ac 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,8 +17,6 @@
 import torch
 from torch.nn import Module
 
-from pytorch_lightning.core.decorators import parameter_validation
-
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -47,7 +45,6 @@ def device(self) -> Union[str, torch.device]:
 
         return device
 
-    @parameter_validation
     def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.
 
@@ -84,9 +81,6 @@ def to(self, *args, **kwargs) -> Module:
             ...     def __init__(self, weight: torch.Tensor):
             ...         super().__init__()
            ...         self.register_buffer('weight', weight)
-            ...
-            ...     def on_post_move_to_device(self):
-            ...         pass
             >>> _ = torch.manual_seed(0)
             >>> module = ExampleModule(torch.rand(3, 4))
             >>> module.weight #doctest: +ELLIPSIS
diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py
index 2d096ee6be2a7..b57894816090d 100644
--- a/tests/accelerators/test_tpu_backend.py
+++ b/tests/accelerators/test_tpu_backend.py
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
         trainer.fit(model)
 
 
-# @RunIf(tpu=True)
-# @pl_multi_process_test
-# def test_if_weights_tied(tmpdir, capsys=None):
-#     """
-#     Test if weights are properly tied on `on_post_move_to_device`.
-#     Ensure no warning for parameter mismatch is thrown.
-#     """
-
-#     # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
-#     class Model(WeightSharingModule):
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """
 
-#         def on_post_move_to_device(self):
-#             self.layer_3.weight = self.layer_1.weight
+    class Model(WeightSharingModule):
 
-#     model = Model()
-#     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
 
-#     with pytest.warns(UserWarning) as warnings:
-#         trainer.fit(model)
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
 
-#     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-#     assert len(trainer.test(model)) == 1
+    with pytest.warns(UserWarning, match="The model layers do not match"):
+        trainer.fit(model)
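
Note on the decorator change above: `parameter_validation` now wraps the plugins' `model_to_device` instead of `DeviceDtypeModuleMixin.to`, so inside the wrapper `self` is the training-type plugin and the LightningModule is reached through `self.model`, which is why the counting logic switches to `self.model.parameters()`. The following is a minimal, self-contained sketch of that pattern, not the library code; `ToyPlugin`, `TiedModule`, and the `warnings.warn` stand-in for `rank_zero_warn` are illustrative only.

# Sketch of the pattern introduced by this diff: the validation decorator wraps a
# *plugin* method, so `self` is the plugin and the module lives on `self.model`.
# Names below (ToyPlugin, TiedModule) are illustrative, not part of pytorch_lightning.
import warnings
from functools import wraps
from typing import Callable

import torch
from torch import nn


def parameter_validation(fn: Callable) -> Callable:

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        # `self` is the plugin; compare parameter counts before and after the move.
        pre_layer_count = len(list(self.model.parameters()))
        module = fn(self, *args, **kwargs)
        self.model.on_post_move_to_device()
        post_layer_count = len(list(self.model.parameters()))

        if pre_layer_count != post_layer_count:
            # stand-in for pytorch_lightning's rank_zero_warn
            warnings.warn('The model layers do not match after moving to the target device.')
        return module

    return inner_fn


class TiedModule(nn.Module):
    """Toy module that re-ties shared weights after a device move."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 4)
        self.layer_3 = nn.Linear(4, 4)
        self.layer_3.weight = self.layer_1.weight  # tie weights

    def on_post_move_to_device(self):
        # a device move can break the tie, so restore it here
        self.layer_3.weight = self.layer_1.weight


class ToyPlugin:
    """Stands in for a training-type plugin holding `.model` and `.root_device`."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.root_device = device

    @parameter_validation
    def model_to_device(self) -> None:
        self.model.to(self.root_device)


if __name__ == "__main__":
    plugin = ToyPlugin(TiedModule(), torch.device("cpu"))
    plugin.model_to_device()  # no warning: counts match after re-tying in on_post_move_to_device

Running the sketch emits no warning because `on_post_move_to_device` restores the tie, so the shared parameter is deduplicated and the counts match before and after the move.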