
Commit

Connect the model to the training type plugin at the start of run (#8536)
carmocca authored Aug 4, 2021
1 parent 49df107 commit ed13040
Showing 25 changed files with 132 additions and 128 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -36,16 +36,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))


+- The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
+
+
 - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


+- The `model` argument of the `Trainer` methods `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
+
+
 - Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


 - Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


--
+- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

 ### Deprecated

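Taken together, the three #8536 entries above describe one refactor: the trainer connects the model to the training type plugin at the very start of the run, so later setup hooks can read the model from the plugin instead of receiving it as an argument. Below is a minimal, self-contained sketch of that pattern in plain Python; the Toy* classes are illustrative stand-ins under that reading, not the actual Lightning classes or call order.

# Minimal, self-contained sketch of the "connect first, then setup" pattern
# described above -- illustrative stand-ins, not pytorch_lightning code.

class ToyTrainingTypePlugin:
    def __init__(self):
        self.model = None

    def connect(self, model):
        # the model is attached to the plugin at the very beginning of the run
        self.model = model

    def setup(self):
        # no `model` argument anymore: the plugin already holds the reference
        print(f"setting up plugin for {type(self.model).__name__}")


class ToyAccelerator:
    def __init__(self, plugin):
        self.training_type_plugin = plugin

    @property
    def lightning_module(self):
        return self.training_type_plugin.model

    def setup(self, trainer):
        # mirrors the new `setup(trainer)` signature with no `model` parameter
        self.training_type_plugin.setup()


class ToyTrainer:
    def __init__(self, accelerator):
        self.accelerator = accelerator

    @property
    def lightning_module(self):
        # usable from the very beginning of the run
        return self.accelerator.lightning_module

    def fit(self, model):
        self.accelerator.training_type_plugin.connect(model)  # step 1: connect the model
        self.accelerator.setup(self)                           # step 2: setup, no model passed


class ToyModel:
    pass


if __name__ == "__main__":
    trainer = ToyTrainer(ToyAccelerator(ToyTrainingTypePlugin()))
    trainer.fit(ToyModel())
    assert trainer.lightning_module is not None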
11 changes: 5 additions & 6 deletions pytorch_lightning/accelerators/accelerator.py
@@ -75,15 +75,14 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()
 
-    def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         """
         Setup plugins for the trainer fit and creates optimizers.
 
         Args:
             trainer: the trainer instance
-            model: the LightningModule
         """
-        self.setup_training_type_plugin(model)
+        self.setup_training_type_plugin()
         if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.setup_precision_plugin()
@@ -334,9 +333,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def setup_training_type_plugin(self, model: "pl.LightningModule") -> None:
+    def setup_training_type_plugin(self) -> None:
         """Attaches the training type plugin to the accelerator."""
-        self.training_type_plugin.setup(model)
+        self.training_type_plugin.setup()
 
     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator"""
@@ -460,7 +459,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.Li
         rank_zero_warn(
             "Accelerator method `connect_training_type_plugin` was deprecated in v1.3. It will be removed in v1.5."
         )
-        self.setup_training_type_plugin(model)
+        self.setup_training_type_plugin()
 
     # todo: remove in v1.5
     def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
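For downstream code that subclasses Accelerator, the override changes roughly as sketched below. This is a sketch based only on the signatures visible in this diff; MyAccelerator is a hypothetical example, not part of the repository, and nothing here claims how such a subclass is wired into a Trainer.

# Hypothetical custom accelerator updated for the new hook signature.
import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator


class MyAccelerator(Accelerator):
    # before this commit: def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None
    def setup(self, trainer: "pl.Trainer") -> None:
        # the model is reachable through the already-connected training type plugin
        model = self.training_type_plugin.lightning_module
        print(f"setting up {type(model).__name__}")
        return super().setup(trainer)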
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu.py
@@ -20,7 +20,7 @@
 class CPUAccelerator(Accelerator):
     """Accelerator for CPU devices."""
 
-    def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
             MisconfigurationException:
@@ -36,4 +36,4 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
         if "cpu" not in str(self.root_device):
             raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")
 
-        return super().setup(trainer, model)
+        return super().setup(trainer)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/gpu.py
@@ -32,14 +32,14 @@ def setup_environment(self) -> None:
             raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
         torch.cuda.set_device(self.root_device)
 
-    def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
             MisconfigurationException:
                 If the selected device is not GPU.
         """
         self.set_nvidia_flags(trainer.local_rank)
-        return super().setup(trainer, model)
+        return super().setup(trainer)
 
     def on_train_start(self) -> None:
         # clear cache before training
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -32,7 +32,7 @@
 class TPUAccelerator(Accelerator):
     """Accelerator for TPU devices."""
 
-    def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    def setup(self, trainer: "pl.Trainer") -> None:
         """
         Raises:
             MisconfigurationException:
@@ -45,7 +45,7 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
 
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
-        return super().setup(trainer, model)
+        return super().setup(trainer)
 
     def run_optimizer_step(
         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -166,11 +166,10 @@ def get_max_batches(self) -> List[Union[int, float]]:
 
     def reload_evaluation_dataloaders(self) -> None:
         """Reloads dataloaders if necessary"""
-        model = self.trainer.lightning_module
         if self.trainer.testing:
-            self.trainer.reset_test_dataloader(model)
+            self.trainer.reset_test_dataloader()
         elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
-            self.trainer.reset_val_dataloader(model)
+            self.trainer.reset_val_dataloader()
 
     def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks"""
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -29,8 +29,7 @@ def global_rank(self) -> int:
     def world_size(self) -> int:
         return self.num_nodes
 
-    def setup(self, model):
-        self._model = model
+    def setup(self) -> None:
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
         # the difference to DDP is that we don't call children processes here
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -136,7 +136,7 @@ def distributed_sampler_kwargs(self):
     def _is_single_process_single_device(self):
         return True
 
-    def setup(self, model):
+    def setup(self) -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         # pass in a state q
         smp = mp.get_context("spawn")
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -48,10 +48,10 @@ def node_rank(self) -> int:
     def world_size(self) -> int:
         return 1
 
-    def setup(self, model):
+    def setup(self) -> None:
         # model needs to be moved to the device before it is wrapped
-        model.to(self.root_device)
-        self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
+        self.model_to_device()
+        self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)
 
     def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
         """
@@ -76,9 +76,8 @@ def mean(t: torch.Tensor) -> torch.Tensor:
     def root_device(self):
         return self.parallel_devices[0]
 
-    def model_to_device(self):
-        # no need to do anything when model is wrapped in torch.nn.DataParallel
-        pass
+    def model_to_device(self) -> None:
+        self._model.to(self.root_device)
 
     def barrier(self, *args, **kwargs):
         pass
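The DataParallel plugin change keeps the rule stated in the inline comment: the module must already be on the root device before torch.nn.DataParallel wraps it. A plain-PyTorch illustration of that ordering follows (standalone snippet, not Lightning code).

# Move the module to the root device first, then wrap it in DataParallel.
import torch
from torch import nn

module = nn.Linear(8, 2)
if torch.cuda.is_available():
    module = module.to("cuda:0")  # root device first
    module = nn.DataParallel(module, device_ids=list(range(torch.cuda.device_count())))

device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch = torch.randn(4, 8, device=device)
output = module(batch)
print(output.shape)  # torch.Size([4, 2])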
9 changes: 5 additions & 4 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -16,7 +16,6 @@
 
 import torch
 from torch import Tensor
-from torch.nn import Module
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -138,9 +137,11 @@ def wrap_policy(*args, **kwargs):
         ):
             yield
 
-    def connect(self, model: Module) -> None:
-        super().connect(model)
-        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
+    def setup_environment(self) -> None:
+        super().setup_environment()
+        model_call_configure_sharded_model_hook = getattr(
+            self.lightning_module, "call_configure_sharded_model_hook", False
+        )
         if not model_call_configure_sharded_model_hook:
             # if model has not called configure sharded model, we reset
             # the training type plugin's call_configure_sharded_model_hook
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -58,8 +58,7 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def setup(self, model):
-        self._model = model
+    def setup(self) -> None:
         self.model_to_device()
 
     def pre_dispatch(self):
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -63,9 +63,8 @@ def root_device(self) -> torch.device:
     def model_to_device(self) -> None:
         self._model.to(self.root_device)
 
-    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
+    def setup(self) -> None:
         self.model_to_device()
-        return self.model
 
     @property
     def is_global_zero(self) -> bool:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -111,9 +111,8 @@ def pre_dispatch(self):
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
-    def setup(self, model: Module) -> Module:
+    def setup(self) -> None:
         self.create_mp_queue()
-        return self.model
 
     def create_mp_queue(self):
         self.start_method = "fork"
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -54,7 +54,7 @@ def setup_environment(self) -> None:
         which allows the user to access the accelerator environment before setup is complete.
         """
 
-    def setup(self, model: Module) -> None:
+    def setup(self) -> None:
         """Called by the accelerator to finish setup."""
 
     @property
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/callback_hook.py
@@ -32,20 +32,20 @@ class TrainerCallbackHookMixin(ABC):
     callbacks: List[Callback] = []
     lightning_module: "pl.LightningModule"
 
-    def on_before_accelerator_backend_setup(self, model: "pl.LightningModule") -> None:
+    def on_before_accelerator_backend_setup(self) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
-            callback.on_before_accelerator_backend_setup(self, model)
+            callback.on_before_accelerator_backend_setup(self, self.lightning_module)
 
-    def configure_sharded_model(self, model: "pl.LightningModule") -> None:
+    def on_configure_sharded_model(self) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
-            callback.on_configure_sharded_model(self, model)
+            callback.on_configure_sharded_model(self, self.lightning_module)
 
-    def setup(self, model: "pl.LightningModule", stage: Optional[str]) -> None:
+    def setup(self, stage: Optional[str]) -> None:
         """Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
         for callback in self.callbacks:
-            callback.setup(self, model, stage=stage)
+            callback.setup(self, self.lightning_module, stage=stage)
 
     def teardown(self, stage: Optional[str] = None) -> None:
         """Called at the end of fit (train + validate), validate, test, or predict, or tune."""
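From a Callback author's perspective the hook signatures are unchanged; the mixin now forwards its own lightning_module instead of an explicit model argument. A small sketch of a user callback that consumes the pl_module it still receives, matching the callback.setup(self, self.lightning_module, stage=stage) call shown above; PrintModelSizeCallback is a hypothetical example, not repository code.

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class PrintModelSizeCallback(Callback):
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # pl_module is the trainer's lightning_module, forwarded by the hook mixin
        n_params = sum(p.numel() for p in pl_module.parameters())
        print(f"[{stage}] model has {n_params} parameters")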
17 changes: 5 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -15,7 +15,6 @@
 
 from datetime import timedelta
 from typing import Dict, List, Optional, Union
 
-import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities import rank_zero_info
@@ -132,25 +131,19 @@ def attach_model_logging_functions(self, model):
             callback.log = model.log
             callback.log_dict = model.log_dict
 
-    @staticmethod
-    def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None:
+    def _attach_model_callbacks(self) -> None:
         """
         Attaches the callbacks defined in the model.
         If a callback returned by the model's configure_callback method has the same type as one or several
         callbacks already present in the trainer callbacks list, it will replace them.
         In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
         will be pushed to the end of the list, ensuring they run last.
-
-        Args:
-            model: A model which may or may not define new callbacks in
-                :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`.
-            trainer: The trainer on which the callbacks get attached/merged.
         """
-        model_callbacks = model.configure_callbacks()
+        model_callbacks = self.trainer.model.configure_callbacks()
         if not model_callbacks:
             return
         model_callback_types = {type(c) for c in model_callbacks}
-        trainer_callback_types = {type(c) for c in trainer.callbacks}
+        trainer_callback_types = {type(c) for c in self.trainer.callbacks}
         override_types = model_callback_types.intersection(trainer_callback_types)
         if override_types:
             rank_zero_info(
@@ -159,11 +152,11 @@ def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None:
                 f" {', '.join(sorted(t.__name__ for t in override_types))}"
             )
         # remove all callbacks with a type that occurs in model callbacks
-        all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types]
+        all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types]
         all_callbacks.extend(model_callbacks)
         all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
         # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
-        trainer.callbacks = all_callbacks
+        self.trainer.callbacks = all_callbacks
 
     @staticmethod
     def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
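The docstring above states the merge rule: callbacks returned by the model replace trainer callbacks of the same type, and checkpoint callbacks are pushed to the end so they run last. A standalone sketch of that rule follows; merge_callbacks and the empty classes are illustrative stand-ins, not the actual Lightning implementation.

from typing import List


class EarlyStopping: ...
class ModelCheckpoint: ...
class ProgressBar: ...


def merge_callbacks(trainer_callbacks: List[object], model_callbacks: List[object]) -> List[object]:
    # model callbacks replace trainer callbacks of the same type
    override_types = {type(c) for c in model_callbacks} & {type(c) for c in trainer_callbacks}
    merged = [c for c in trainer_callbacks if type(c) not in override_types]
    merged.extend(model_callbacks)
    # checkpoint callbacks are pushed to the end so they run last
    checkpoints = [c for c in merged if isinstance(c, ModelCheckpoint)]
    others = [c for c in merged if not isinstance(c, ModelCheckpoint)]
    return others + checkpoints


trainer_cbs = [EarlyStopping(), ModelCheckpoint(), ProgressBar()]
model_cbs = [EarlyStopping()]  # replaces the trainer's EarlyStopping
merged = merge_callbacks(trainer_cbs, model_cbs)
print([type(c).__name__ for c in merged])  # ['ProgressBar', 'EarlyStopping', 'ModelCheckpoint']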
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -66,13 +66,13 @@ def get_profiled_train_dataloader(self, train_dataloader):
         )
         return profiled_dl
 
-    def prepare_data(self, model):
+    def prepare_data(self) -> None:
        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
         if self.can_prepare_data():
             if self.trainer.datamodule is not None:
                 self.trainer.datamodule.prepare_data()
-            model.prepare_data()
+            self.trainer.lightning_module.prepare_data()
             self.trainer._is_data_prepared = True
 
     def can_prepare_data(self):
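The comment in prepare_data reflects the user-facing contract: prepare_data runs on a single process (downloads, writing to disk), while setup runs on every process. A short sketch under that assumption; MyDataModule is a hypothetical example, not repository code.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def prepare_data(self) -> None:
        # called on a single process: download, tokenize, write to disk, ...
        print("downloading / preparing data once")

    def setup(self, stage=None) -> None:
        # called on every process: assign datasets, splits, transforms, ...
        self.train_set = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_set, batch_size=16)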