Remove deprecated LightningModule.on_post_move_to_device (Lightning-AI#13548)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
2 people authored and jerome-habana committed Jul 14, 2022
1 parent 604f7ca commit 692da6a
Showing 8 changed files with 10 additions and 77 deletions.
18 changes: 6 additions & 12 deletions docs/source-pytorch/accelerators/tpu_advanced.rst
@@ -12,11 +12,10 @@ Weight Tying/Sharing is a technique where in the module weights are shared among
This is a common method to reduce memory consumption and is utilized in many State of the Art
architectures today.

PyTorch XLA requires these weights to be tied/shared after moving the model
to the TPU device. To support this requirement Lightning provides a model hook which is
called after the model is moved to the device. Any weights that require to be tied should
be done in the `on_post_move_to_device` model hook. This will ensure that the weights
among the modules are shared and not copied.
PyTorch XLA requires these weights to be tied/shared after moving the model to the XLA device.
To support this requirement, Lightning automatically finds these weights and ties them after
the modules are moved to the XLA device under the hood. It will ensure that the weights among
the modules are shared but not copied independently.

PyTorch Lightning has an inbuilt check which verifies that the model parameter lengths
match once the model is moved to the device. If the lengths do not match Lightning
@@ -37,9 +36,8 @@ Example:
self.layer_1 = nn.Linear(32, 10, bias=False)
self.layer_2 = nn.Linear(10, 32, bias=False)
self.layer_3 = nn.Linear(32, 10, bias=False)
# TPU shared weights are copied independently
# on the XLA device and this line won't have any effect.
# However, it works fine for CPU and GPU.
# Lightning automatically ties these weights after moving to the XLA device,
# so all you need is to write the following just like on other accelerators.
self.layer_3.weight = self.layer_1.weight
def forward(self, x):
@@ -48,10 +46,6 @@ Example:
x = self.layer_3(x)
return x
def on_post_move_to_device(self):
# Weights shared after the model has been moved to TPU Device
self.layer_3.weight = self.layer_1.weight
model = WeightSharingModule()
trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)
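As a side note, the parameter-length check mentioned in the documentation above can be reproduced with plain PyTorch, because nn.Module.parameters() yields a tied parameter only once. This is a minimal sketch that reuses the layer sizes from the example above; it is an illustration, not Lightning's internal check:

    import torch.nn as nn

    class WeightSharingModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # Tie once, in __init__; no hook is needed anymore.
            self.layer_3.weight = self.layer_1.weight

    module = WeightSharingModule()
    # The tied weight is reported once, so three layers expose two parameters.
    assert len(list(module.parameters())) == 2
    # Both attribute names still point at the same underlying storage.
    assert module.layer_1.weight.data_ptr() == module.layer_3.weight.data_ptr()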
6 changes: 0 additions & 6 deletions docs/source-pytorch/common/lightning_module.rst
@@ -1501,12 +1501,6 @@ on_validation_epoch_end
.. automethod:: pytorch_lightning.core.module.LightningModule.on_validation_epoch_end
:noindex:

on_post_move_to_device
~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.module.LightningModule.on_post_move_to_device
:noindex:

configure_sharded_model
~~~~~~~~~~~~~~~~~~~~~~~

3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -287,6 +287,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `Callback.on_keyboard_interrupt` ([#13438](https://github.com/Lightning-AI/lightning/pull/13438))


- Removed deprecated `LightningModule.on_post_move_to_device` ([#13548](https://github.com/Lightning-AI/lightning/pull/13548))


### Fixed


15 changes: 0 additions & 15 deletions src/pytorch_lightning/core/hooks.py
@@ -298,21 +298,6 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx):
)
"""

def on_post_move_to_device(self) -> None:
"""Called in the ``parameter_validation`` decorator after
:meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between
modules after moving them to a device. Can be used when training models with weight sharing properties on
TPU.
Addresses the handling of shared weights on TPU:
https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
Example::
def on_post_move_to_device(self):
self.decoder.weight = self.encoder.weight
"""

def configure_sharded_model(self) -> None:
"""Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
where we'd like to shard the model instantly, which is useful for extremely large models which can save
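With the hook removed, a module that used to tie weights in on_post_move_to_device now ties them directly in __init__ and relies on the automatic tying described in the updated TPU docs. A minimal migration sketch; the encoder/decoder attribute names follow the removed docstring example, while the class name and layer sizes are made up for illustration:

    import torch.nn as nn
    import pytorch_lightning as pl

    class TiedModule(pl.LightningModule):
        # Illustrative module; not part of the Lightning codebase.
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(32, 32, bias=False)
            self.decoder = nn.Linear(32, 32, bias=False)
            # Previously done in the now-removed on_post_move_to_device hook.
            # Tying here is enough; Lightning re-ties shared weights after the
            # model is moved to the XLA device.
            self.decoder.weight = self.encoder.weight

        def forward(self, x):
            return self.decoder(self.encoder(x))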
6 changes: 0 additions & 6 deletions src/pytorch_lightning/overrides/base.py
@@ -52,9 +52,6 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any:
def forward(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError

def on_post_move_to_device(self) -> None:
pass


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
@@ -95,9 +92,6 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
return self.module.predict_step(*inputs, **kwargs)
return self.module(*inputs, **kwargs)

def on_post_move_to_device(self) -> None:
pass


def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
"""Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
7 changes: 1 addition & 6 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -33,7 +33,6 @@
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
@@ -124,11 +123,7 @@ def setup(self, trainer: "pl.Trainer") -> None:

shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.module.on_post_move_to_device()
else:
set_shared_parameters(self.model.module, shared_params)

set_shared_parameters(self.model.module, shared_params)
self.setup_precision_plugin()

if trainer.state.fn == TrainerFn.FITTING:
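The reworked setup() above always records which parameters are shared before model_to_device() and re-ties them afterwards, because moving a module to an XLA device can materialize each parameter as an independent copy. Roughly, the idea looks like the following conceptual sketch; the helper names here are hypothetical stand-ins, not the pytorch_lightning.utilities functions used in the diff:

    # Hypothetical helpers for illustration only; not the library's
    # find_shared_parameters / set_shared_parameters implementations.
    from typing import Dict, List, Tuple

    import torch.nn as nn

    def find_shared(model: nn.Module) -> List[List[str]]:
        """Group fully qualified parameter names that point at the same tensor object."""
        groups: Dict[int, List[str]] = {}
        for mod_name, mod in model.named_modules():
            for p_name, param in mod.named_parameters(recurse=False):
                full_name = f"{mod_name}.{p_name}" if mod_name else p_name
                groups.setdefault(id(param), []).append(full_name)
        return [names for names in groups.values() if len(names) > 1]

    def retie(model: nn.Module, shared: List[List[str]]) -> None:
        """Point every name in a group back at the first parameter after the device move."""
        for names in shared:
            owner, attr = resolve(model, names[0])
            source = getattr(owner, attr)
            for name in names[1:]:
                owner, attr = resolve(model, name)
                setattr(owner, attr, source)

    def resolve(model: nn.Module, dotted_name: str) -> Tuple[nn.Module, str]:
        """Split 'block.layer.weight' into the owning submodule and the attribute name."""
        *path, attr = dotted_name.split(".")
        for part in path:
            model = getattr(model, part)
        return model, attr

    # Order of operations mirrored from setup():
    # shared = find_shared(model); model.to(device); retie(model, shared)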
16 changes: 0 additions & 16 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -46,8 +46,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
__verify_eval_loop_configuration(trainer, model, "predict")

__verify_dp_batch_transfer_support(trainer, model)
# TODO: Delete _check_on_post_move_to_device in v1.7
_check_on_post_move_to_device(model)
_check_deprecated_callback_hooks(trainer)
# TODO: Delete _check_on_hpc_hooks in v1.8
_check_on_hpc_hooks(model)
@@ -122,20 +120,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
)


def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
r"""
Checks if `on_post_move_to_device` method is overridden and sends a deprecation warning.
Args:
model: The model to check the `on_post_move_to_device` method.
"""
if is_overridden("on_post_move_to_device", model):
rank_zero_deprecation(
"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
"We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
)


def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
loader_name = f"{stage}_dataloader"
step_name = "validation_step" if stage == "val" else f"{stage}_step"
16 changes: 0 additions & 16 deletions tests/tests_pytorch/deprecated_api/test_remove_1-7.py
@@ -20,7 +20,6 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
@@ -39,21 +38,6 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
_ = LightningDistributed()


def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
class TestModel(BoringModel):
def on_post_move_to_device(self):
print("on_post_move_to_device")

model = TestModel()

trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)

with pytest.deprecated_call(
match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
):
trainer.fit(model)


def test_v1_7_0_deprecated_max_steps_none(tmpdir):
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
_ = Trainer(max_steps=None)
