Move parameter validation specific to TPU Training plugins (#7415)
* Move parameter validation specific to TPU Training plugins

* update docstring
kaushikb11 authored May 24, 2021
1 parent fa41c58 commit 3f460b1
Showing 5 changed files with 22 additions and 29 deletions.
9 changes: 4 additions & 5 deletions pytorch_lightning/core/decorators.py
@@ -71,12 +71,11 @@ def auto_transfer_args(self, *args, **kwargs):

 def parameter_validation(fn: Callable) -> Callable:
     """
-    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
     Validates that the module parameter lengths match after moving to the device. It is useful
     when tying weights on TPU's.

     Args:
-        fn: ``.to`` method
+        fn: ``model_to_device`` method

     Note:
         TPU's require weights to be tied/shared after moving the module to the device.
@@ -90,10 +89,10 @@ def parameter_validation(fn: Callable) -> Callable:

     @wraps(fn)
     def inner_fn(self, *args, **kwargs):
-        pre_layer_count = len(list(self.parameters()))
+        pre_layer_count = len(list(self.model.parameters()))
         module = fn(self, *args, **kwargs)
-        self.on_post_move_to_device()
-        post_layer_count = len(list(self.parameters()))
+        self.model.on_post_move_to_device()
+        post_layer_count = len(list(self.model.parameters()))

         if not pre_layer_count == post_layer_count:
             rank_zero_warn(
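With the decorator relocated, the parameter-count check now runs on the plugin object, which is why the body above reads `self.model.parameters()` rather than `self.parameters()`. A minimal, runnable sketch of the mechanism follows; `TiedModel`, `ToyPlugin`, and the warning text are illustrative stand-ins rather than Lightning classes, and it runs on CPU only so the count comparison can be demonstrated anywhere:

# Illustrative sketch only: simplified stand-ins for the TPU plugins, showing why
# tying weights changes the number of unique parameters reported by the module.
import warnings
from functools import wraps
from typing import Callable

import torch
from torch import nn


def parameter_validation(fn: Callable) -> Callable:
    # Same idea as the decorator above: count parameters before and after the move.
    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.model.parameters()))
        result = fn(self, *args, **kwargs)
        self.model.on_post_move_to_device()
        post_layer_count = len(list(self.model.parameters()))
        if pre_layer_count != post_layer_count:
            warnings.warn(
                f"The model layers do not match: {pre_layer_count} -> {post_layer_count}"
            )
        return result

    return inner_fn


class TiedModel(nn.Module):
    # Hypothetical module: layer_3 shares layer_1's weight once the model is on device.
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 4)
        self.layer_3 = nn.Linear(4, 4)

    def on_post_move_to_device(self):
        self.layer_3.weight = self.layer_1.weight  # tying drops one unique parameter


class ToyPlugin:
    # Hypothetical stand-in for SingleTPUPlugin; CPU is used so the sketch runs anywhere.
    def __init__(self, model: nn.Module):
        self.model = model
        self.root_device = torch.device("cpu")

    @parameter_validation
    def model_to_device(self) -> None:
        self.model.to(self.root_device)


ToyPlugin(TiedModel()).model_to_device()  # warns: 4 parameters before, 3 after tying

In the commit itself, the same check now guards `SingleTPUPlugin.model_to_device` and `TPUSpawnPlugin.model_to_device`, as shown in the two files below.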
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -15,6 +15,7 @@

 import torch

+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -39,6 +40,7 @@ def __init__(self, device: int, debug: bool = False):
     def is_distributed(self) -> bool:
         return False

+    @parameter_validation
     def model_to_device(self) -> None:
         self.model.to(self.root_device)

2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -23,6 +23,7 @@
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
+from pytorch_lightning.core.decorators import parameter_validation
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
@@ -171,6 +172,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)

+    @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)

6 changes: 0 additions & 6 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,8 +17,6 @@
 import torch
 from torch.nn import Module

-from pytorch_lightning.core.decorators import parameter_validation
-

 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -47,7 +45,6 @@ def device(self) -> Union[str, torch.device]:

         return device

-    @parameter_validation
     def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.
@@ -84,9 +81,6 @@ def to(self, *args, **kwargs) -> Module:
             ...     def __init__(self, weight: torch.Tensor):
             ...         super().__init__()
             ...         self.register_buffer('weight', weight)
-            ...
-            ...     def on_post_move_to_device(self):
-            ...         pass
             >>> _ = torch.manual_seed(0)
             >>> module = ExampleModule(torch.rand(3, 4))
             >>> module.weight #doctest: +ELLIPSIS
32 changes: 14 additions & 18 deletions tests/accelerators/test_tpu_backend.py
@@ -95,25 +95,21 @@ def test_weight_tying_warning(tmpdir, capsys=None):
         trainer.fit(model)


-# @RunIf(tpu=True)
-# @pl_multi_process_test
-# def test_if_weights_tied(tmpdir, capsys=None):
-#     """
-#     Test if weights are properly tied on `on_post_move_to_device`.
-#     Ensure no warning for parameter mismatch is thrown.
-#     """
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """

-#     # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
-#     class Model(WeightSharingModule):
+    class Model(WeightSharingModule):

-#         def on_post_move_to_device(self):
-#             self.layer_3.weight = self.layer_1.weight
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight

-#     model = Model()
-#     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

-#     with pytest.warns(UserWarning) as warnings:
-#         trainer.fit(model)
-
-#     assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
-#     assert len(trainer.test(model)) == 1
+    with pytest.warns(UserWarning, match="The model layers do not match"):
+        trainer.fit(model)
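The re-enabled test subclasses `WeightSharingModule`, a helper defined elsewhere in this test file and not shown in the diff. A plausible sketch of such a helper, with layer names inferred from the `layer_1`/`layer_3` references above and sizes chosen arbitrarily (the repository's actual definition builds on its own base test model and may differ):

# Assumed shape of the WeightSharingModule helper used by test_if_weights_tied.
import torch
from torch import nn
from pytorch_lightning import LightningModule


class WeightSharingModule(LightningModule):
    # Three linear layers, with layer_3 intended to share layer_1's weight.
    # Training hooks are omitted for brevity; the real helper presumably provides
    # training_step/configure_optimizers/dataloaders via the test base model.
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer_3(self.layer_2(self.layer_1(x)))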
