From eebc4ded876ea4b1aff70788b76fce52c595d104 Mon Sep 17 00:00:00 2001
From: Lezwon Castelino
Date: Sun, 10 Jan 2021 16:44:39 +0530
Subject: [PATCH] docs and refactors

---
 docs/source/tpu.rst                  | 54 +++++++++++++++++++++++++++-
 pytorch_lightning/core/decorators.py | 33 ++++++++++++++++++
 tests/backends/test_tpu_backend.py   | 34 ++----------------
 tests/base/weight_sharing_module.py  | 18 +++++++++
 4 files changed, 107 insertions(+), 32 deletions(-)
 create mode 100644 tests/base/weight_sharing_module.py

diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst
index 5f4c48076d813..89a66efeec27e 100644
--- a/docs/source/tpu.rst
+++ b/docs/source/tpu.rst
@@ -192,7 +192,59 @@ set the 16-bit flag.
 Under the hood the xla library will use the `bfloat16 type
 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
 
------------------
+
+-----------------
+
+Weight Sharing/Tying
+-----------------------
+Weight Tying/Sharing is a technique in which module weights are shared among two or more layers.
+It is a common way to reduce memory consumption and is used in many state-of-the-art
+architectures today.
+
+PyTorch XLA requires these weights to be tied/shared again after moving the model
+to the TPU device. To support this requirement, Lightning provides a model hook which is
+called after the model has been moved to the device. Any weights that need to be tied should
+be tied inside the `on_post_move_to_device` model hook. This ensures that the weights
+among the modules are shared rather than copied.
+
+PyTorch Lightning has a built-in check which verifies that the number of model parameters
+is unchanged once the model has been moved to the device. If the counts do not match, Lightning
+throws a warning.
+
+Example:
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torch import nn
+
+
+    class WeightSharingModule(pl.LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Linear(32, 10, bias=False)
+            self.layer_2 = nn.Linear(10, 32, bias=False)
+            self.layer_3 = nn.Linear(32, 10, bias=False)
+            self.layer_3.weight = self.layer_1.weight  # these weights will be copied (not shared) on the TPU
+
+        def forward(self, x):
+            x = self.layer_1(x)
+            x = self.layer_2(x)
+            x = self.layer_3(x)
+            return x
+
+        def on_post_move_to_device(self):
+            # re-tie the weights after the model has been moved to the TPU device
+            self.layer_3.weight = self.layer_1.weight
+
+
+    model = WeightSharingModule()
+    trainer = pl.Trainer(max_epochs=1, tpu_cores=8)
+    trainer.fit(model)
+
+See the `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_ for more details.
+
+-----------------------
 
 About XLA
 ----------
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index af9d2e3bb7323..d63b390eb1484 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -68,6 +68,39 @@ def auto_transfer_args(self, *args, **kwargs):
 
 
 def parameter_validation(fn: Callable) -> Callable:
+    """
+    Decorator for the `~pytorch_lightning.core.LightningModule.to` method.
+    Validates that the number of module parameters is unchanged after moving to the device.
+    It is useful when tying weights on TPUs.
+
+    Args:
+        fn: the `.to` method
+
+    Note:
+        TPUs require weights to be tied/shared after moving the module to the device.
+        Failure to do this results in the initialization of new weights that are not tied.
+        To overcome this issue, weights should be tied using the `on_post_move_to_device` model hook,
+        which is called after the module has been moved to the device.
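+
+    Example:
+        A minimal sketch of the tying pattern this decorator validates (the layer
+        names and sizes are illustrative and ``nn`` refers to ``torch.nn``)::
+
+            class TiedModule(LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.layer_1 = nn.Linear(32, 10, bias=False)
+                    self.layer_3 = nn.Linear(32, 10, bias=False)
+                    self.layer_3.weight = self.layer_1.weight  # tied on the host
+
+                def on_post_move_to_device(self):
+                    # re-tie after the module has been moved so the weights
+                    # stay shared on the device rather than copied
+                    self.layer_3.weight = self.layer_1.weight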
+
+    See Also:
+        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
+    """
     @wraps(fn)
     def inner_f(self, *args, **kwargs):
         pre_param_count = len(list(self.parameters()))
diff --git a/tests/backends/test_tpu_backend.py b/tests/backends/test_tpu_backend.py
index aa9590316f06c..7c1b19504632d 100644
--- a/tests/backends/test_tpu_backend.py
+++ b/tests/backends/test_tpu_backend.py
@@ -14,13 +14,12 @@
 
 import pytest
 import torch
-from torch import nn
 
 from pytorch_lightning import Trainer
-from tests.base import SimpleModule
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils
 from tests.base.boring_model import BoringModel
 from tests.base.develop_utils import pl_multi_process_test
+from tests.base.weight_sharing_module import WeightSharingModule
 
 
 @pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@@ -73,20 +72,6 @@ def test_weight_tying_warning(tmpdir, capsys=None):
     post moving to device.
     """
 
-    class WeightSharingModule(SimpleModule):
-        def __init__(self):
-            super().__init__()
-            self.layer_1 = nn.Linear(32, 10, bias=False)
-            self.layer_2 = nn.Linear(10, 32, bias=False)
-            self.layer_3 = nn.Linear(32, 10, bias=False)
-            self.layer_3.weight = self.layer_1.weight
-
-        def forward(self, x):
-            x = self.layer_1(x)
-            x = self.layer_2(x)
-            x = self.layer_3(x)
-            return x
-
     model = WeightSharingModule()
     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
 
@@ -103,24 +88,11 @@ def test_if_weights_tied(tmpdir, capsys=None):
     Ensure no warning for parameter mismatch is thrown.
     """
 
-    class WeightSharingModule(SimpleModule):
-        def __init__(self):
-            super().__init__()
-            self.layer_1 = nn.Linear(32, 10, bias=False)
-            self.layer_2 = nn.Linear(10, 32, bias=False)
-            self.layer_3 = nn.Linear(32, 10, bias=False)
-            self.layer_3.weight = self.layer_1.weight
-
-        def forward(self, x):
-            x = self.layer_1(x)
-            x = self.layer_2(x)
-            x = self.layer_3(x)
-            return x
-
+    class Model(WeightSharingModule):
         def on_post_move_to_device(self):
             self.layer_3.weight = self.layer_1.weight
 
-    model = WeightSharingModule()
+    model = Model()
     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
 
     with pytest.warns(UserWarning) as warnings:
diff --git a/tests/base/weight_sharing_module.py b/tests/base/weight_sharing_module.py
new file mode 100644
index 0000000000000..924b5233aa7cb
--- /dev/null
+++ b/tests/base/weight_sharing_module.py
@@ -0,0 +1,18 @@
+from torch import nn
+
+from tests.base import SimpleModule
+
+
+class WeightSharingModule(SimpleModule):
+    def __init__(self):
+        super().__init__()
+        self.layer_1 = nn.Linear(32, 10, bias=False)
+        self.layer_2 = nn.Linear(10, 32, bias=False)
+        self.layer_3 = nn.Linear(32, 10, bias=False)
+        self.layer_3.weight = self.layer_1.weight
+
+    def forward(self, x):
+        x = self.layer_1(x)
+        x = self.layer_2(x)
+        x = self.layer_3(x)
+        return x