Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connect the model to the training type plugin at the start of run #8536

Merged
merged 21 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))


- The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))


- The data-loading `Trainer` functions' `model` argument is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

Expand Down Expand Up @@ -71,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


-
Expand Down
13 changes: 6 additions & 7 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,14 @@ def setup_environment(self) -> None:
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Setup plugins for the trainer fit and creates optimizers.

Args:
trainer: the trainer instance
model: the LightningModule
"""
self.setup_training_type_plugin(model)
self.setup_training_type_plugin()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small note: This could be an issue if we ever decide to expose the Accelerator API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why an issue? There's the connect hook to connect the model already.

self.model should be available for the plugin when setup is called

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@carmocca is this an invariant? should the accelerator assert that the model is available before calling setup training type plugin?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ananthsub We could

carmocca marked this conversation as resolved.
Show resolved Hide resolved
if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.setup_precision_plugin()
Expand Down Expand Up @@ -334,9 +333,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def setup_training_type_plugin(self, model: "pl.LightningModule") -> None:
def setup_training_type_plugin(self) -> None:
"""Attaches the training type plugin to the accelerator."""
self.training_type_plugin.setup(model)
self.training_type_plugin.setup()

def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator"""
Expand Down Expand Up @@ -449,7 +448,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
yield

# todo: remove in v1.5
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.LightningModule") -> None:
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.LightningModule") -> None: # noqa
"""
Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin
Expand All @@ -460,7 +459,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.Li
rank_zero_warn(
"Accelerator method `connect_training_type_plugin` was deprecated in v1.3. It will be removed in v1.5."
)
self.setup_training_type_plugin(model)
self.setup_training_type_plugin()

# todo: remove in v1.5
def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class CPUAccelerator(Accelerator):
"""Accelerator for CPU devices."""

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
Expand All @@ -36,4 +36,4 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")

return super().setup(trainer, model)
return super().setup(trainer)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def setup_environment(self) -> None:
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
torch.cuda.set_device(self.root_device)

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
self.set_nvidia_flags(trainer.local_rank)
return super().setup(trainer, model)
return super().setup(trainer)

def on_train_start(self) -> None:
# clear cache before training
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
Expand All @@ -45,7 +45,7 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)
return super().setup(trainer)

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,10 @@ def get_max_batches(self) -> List[Union[int, float]]:

def reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary"""
model = self.trainer.lightning_module
if self.trainer.testing:
self.trainer.reset_test_dataloader(model)
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader(model)
self.trainer.reset_val_dataloader()

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks"""
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def global_rank(self) -> int:
def world_size(self) -> int:
return self.num_nodes

def setup(self, model):
self._model = model
def setup(self) -> None:
# set the task idx
self.task_idx = self.cluster_environment.local_rank()
# the difference to DDP is that we don't call children processes here
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self):
return True

def setup(self, model):
def setup(self) -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
# pass in a state q
smp = mp.get_context("spawn")
Expand Down
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def node_rank(self) -> int:
def world_size(self) -> int:
return 1

def setup(self, model):
def setup(self) -> None:
# model needs to be moved to the device before it is wrapped
model.to(self.root_device)
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
self.model_to_device()
self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)

def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
Expand All @@ -76,9 +76,8 @@ def mean(t: torch.Tensor) -> torch.Tensor:
def root_device(self):
return self.parallel_devices[0]

def model_to_device(self):
# no need to do anything when model is wrapped in torch.nn.DataParallel
pass
def model_to_device(self) -> None:
self._model.to(self.root_device)

def barrier(self, *args, **kwargs):
pass
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model
def setup(self) -> None:
self.model_to_device()

def pre_dispatch(self):
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def root_device(self) -> torch.device:
def model_to_device(self) -> None:
self._model.to(self.root_device)

def setup(self, model: torch.nn.Module) -> torch.nn.Module:
def setup(self) -> None:
self.model_to_device()
return self.model

@property
def is_global_zero(self) -> bool:
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def pre_dispatch(self):
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

def setup(self, model: Module) -> Module:
def setup(self) -> None:
self.create_mp_queue()
return self.model

def create_mp_queue(self):
self.start_method = "fork"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def setup_environment(self) -> None:
which allows the user to access the accelerator environment before setup is complete.
"""

def setup(self, model: Module) -> None:
def setup(self) -> None:
"""Called by the accelerator to finish setup."""

@property
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
lightning_module: "pl.LightningModule"

def on_before_accelerator_backend_setup(self, model: "pl.LightningModule") -> None:
def on_before_accelerator_backend_setup(self) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_before_accelerator_backend_setup(self, model)
callback.on_before_accelerator_backend_setup(self, self.lightning_module)

def configure_sharded_model(self, model: "pl.LightningModule") -> None:
def on_configure_sharded_model(self) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_configure_sharded_model(self, model)
callback.on_configure_sharded_model(self, self.lightning_module)

def setup(self, model: "pl.LightningModule", stage: Optional[str]) -> None:
def setup(self, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.setup(self, model, stage=stage)
callback.setup(self, self.lightning_module, stage=stage)

def teardown(self, stage: Optional[str] = None) -> None:
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
Expand Down
17 changes: 5 additions & 12 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from datetime import timedelta
from typing import Dict, List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import rank_zero_info
Expand Down Expand Up @@ -132,25 +131,19 @@ def attach_model_logging_functions(self, model):
callback.log = model.log
callback.log_dict = model.log_dict

@staticmethod
def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None:
def _attach_model_callbacks(self) -> None:
"""
Attaches the callbacks defined in the model.
If a callback returned by the model's configure_callback method has the same type as one or several
callbacks already present in the trainer callbacks list, it will replace them.
In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
will be pushed to the end of the list, ensuring they run last.

Args:
model: A model which may or may not define new callbacks in
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`.
trainer: The trainer on which the callbacks get attached/merged.
"""
model_callbacks = model.configure_callbacks()
model_callbacks = self.trainer.model.configure_callbacks()
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if not model_callbacks:
return
model_callback_types = {type(c) for c in model_callbacks}
trainer_callback_types = {type(c) for c in trainer.callbacks}
trainer_callback_types = {type(c) for c in self.trainer.callbacks}
override_types = model_callback_types.intersection(trainer_callback_types)
if override_types:
rank_zero_info(
Expand All @@ -159,11 +152,11 @@ def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None:
f" {', '.join(sorted(t.__name__ for t in override_types))}"
)
# remove all callbacks with a type that occurs in model callbacks
all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types]
all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types]
all_callbacks.extend(model_callbacks)
all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
# TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
trainer.callbacks = all_callbacks
self.trainer.callbacks = all_callbacks

@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ def get_profiled_train_dataloader(self, train_dataloader):
)
return profiled_dl

def prepare_data(self, model):
def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
if self.trainer.datamodule is not None:
self.trainer.datamodule.prepare_data()
model.prepare_data()
self.trainer.lightning_module.prepare_data()
self.trainer._is_data_prepared = True

def can_prepare_data(self):
Expand Down
Loading