Commit

docs and refactors

lezwon committed Jan 10, 2021
1 parent 46bc680 commit eebc4de
Showing 4 changed files with 91 additions and 32 deletions.
54 changes: 53 additions & 1 deletion docs/source/tpu.rst
@@ -192,7 +192,59 @@ set the 16-bit flag.
Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

----------------

-----------------

Weight Sharing/Tying
-----------------------
Weight Tying/Sharing is a technique in which the weights of a module are shared among two or more layers.
This is a common method to reduce memory consumption and is used in many state-of-the-art
architectures today.

PyTorch XLA requires these weights to be tied/shared after moving the model
to the TPU device. To support this requirement, Lightning provides a model hook which is
called after the model is moved to the device. Any weights that need to be tied should
be tied in the `on_post_move_to_device` model hook. This ensures that the weights
among the modules are shared and not copied.

PyTorch Lightning has a built-in check which verifies that the number of model parameters
is unchanged once the model is moved to the device. If the counts do not match, Lightning
throws a warning message.

Example:

.. code-block:: python

    import pytorch_lightning as pl
    from torch import nn


    class WeightSharingModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # Weights tied here will be copied (not shared) when the model is moved to the TPU
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

        def on_post_move_to_device(self):
            # Tie the weights after the model has been moved to the TPU device
            self.layer_3.weight = self.layer_1.weight


    model = WeightSharingModule()
    trainer = pl.Trainer(max_epochs=1, tpu_cores=8)
    result = trainer.fit(model)

See the `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
for more details on XLA tensor quirks.
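
Because tied parameters are yielded only once by `parameters()`, the built-in check effectively
compares the number of unique parameters before and after the move. A minimal sketch of that idea,
using the model above (illustrative only, not part of the Lightning API):

.. code-block:: python

    model = WeightSharingModule()

    # layer_3 reuses layer_1's weight, so only two unique weight tensors are reported
    assert len(list(model.parameters())) == 2

    # after `on_post_move_to_device` has run on the TPU, both layers should reference
    # the same parameter object instead of holding separate copies
    assert model.layer_3.weight is model.layer_1.weight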

-----------------------

About XLA
----------
17 changes: 17 additions & 0 deletions pytorch_lightning/core/decorators.py
@@ -68,6 +68,23 @@ def auto_transfer_args(self, *args, **kwargs):


def parameter_validation(fn: Callable) -> Callable:
    """
    Decorator for the `~pytorch_lightning.core.LightningModule.to` method.

    Validates that the number of module parameters stays the same after moving to the device.
    It is useful when tying weights on TPUs.

    Args:
        fn: the `.to` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failure to do this results in the initialization of new weights which are not tied.
        To overcome this issue, weights should be tied using the `on_post_move_to_device`
        model hook, which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """

    @wraps(fn)
    def inner_f(self, *args, **kwargs):
        pre_param_count = len(list(self.parameters()))
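The remainder of `inner_f` is collapsed in this view. A rough, self-contained sketch of the kind of
check the docstring describes might look like the following; the exact body and the warning text are
assumptions, not the committed code:

from functools import wraps
from typing import Callable

from pytorch_lightning.utilities import rank_zero_warn


def parameter_validation(fn: Callable) -> Callable:
    @wraps(fn)
    def inner_f(self, *args, **kwargs):
        pre_param_count = len(list(self.parameters()))
        module = fn(self, *args, **kwargs)  # run the wrapped `.to(...)` call
        post_param_count = len(list(self.parameters()))

        # Weights that were not re-tied in `on_post_move_to_device` end up duplicated
        # by XLA during the move, which changes the unique-parameter count.
        if post_param_count != pre_param_count:
            rank_zero_warn(
                "The model layers do not match after moving to the target device."
                " If your model employs weight sharing on TPU, please tie your"
                " weights using the `on_post_move_to_device` model hook."  # assumed wording
            )
        return module

    return inner_f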
34 changes: 3 additions & 31 deletions tests/backends/test_tpu_backend.py
@@ -14,13 +14,12 @@

import pytest
import torch
from torch import nn

from pytorch_lightning import Trainer
from tests.base import SimpleModule
from pytorch_lightning.utilities.xla_device import XLADeviceUtils
from tests.base.boring_model import BoringModel
from tests.base.develop_utils import pl_multi_process_test
from tests.base.weight_sharing_module import WeightSharingModule


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@@ -73,20 +72,6 @@ def test_weight_tying_warning(tmpdir, capsys=None):
post moving to device.
"""

    class WeightSharingModule(SimpleModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

    model = WeightSharingModule()
    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

@@ -103,24 +88,11 @@ def test_if_weights_tied(tmpdir, capsys=None):
Ensure no warning for parameter mismatch is thrown.
"""

    class WeightSharingModule(SimpleModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

    class Model(WeightSharingModule):
        def on_post_move_to_device(self):
            self.layer_3.weight = self.layer_1.weight

    model = WeightSharingModule()
    model = Model()
    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

    with pytest.warns(UserWarning) as warnings:
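The rest of this test is collapsed. A plausible continuation, shown only for illustration (the fit
call and the exact assertions, including the warning substring, are assumptions):

        result = trainer.fit(model)
        assert result

    # The weights were re-tied in `on_post_move_to_device`, so the recorded warnings
    # should not contain the parameter-mismatch message (message text assumed).
    assert not any("do not match" in str(w.message) for w in warnings.list)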
18 changes: 18 additions & 0 deletions tests/base/weight_sharing_module.py
@@ -0,0 +1,18 @@
from torch import nn

from tests.base import SimpleModule


class WeightSharingModule(SimpleModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x
