Merge branch 'master' into feat/cli-restructuring
carmocca authored Aug 4, 2021
2 parents ed326ef + 963c267 commit f4e5bee
Showing 28 changed files with 188 additions and 144 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -36,9 +36,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))


- The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


- The `model` argument of the `Trainer` functions `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


@@ -51,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `LightningCLI.instantiate_trainer` now takes a config and a list of callbacks. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
* Split `LightningCLI.add_core_arguments_to_parser` into `LightningCLI.add_default_arguments_to_parser` + `LightningCLI.add_core_arguments_to_parser`. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))


- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
@@ -112,6 +121,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))


- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))



## [1.4.0] - 2021-07-27

### Added
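Most of this commit implements the `setup` entries above: the accelerator and training type plugin `setup` hooks drop their `model` argument and read the module from the trainer instead. A minimal sketch, not part of the commit, of what a custom training type plugin looks like after the change; the `MyPlugin` name and the print statement are illustrative only, and it assumes the 1.4-era `SingleDevicePlugin` base class:

import torch

import pytorch_lightning as pl
from pytorch_lightning.plugins import SingleDevicePlugin


class MyPlugin(SingleDevicePlugin):
    # After this change, `setup()` takes no `model` argument; the wrapped
    # LightningModule is reachable through `self.lightning_module`.
    def setup(self) -> None:
        print(f"Preparing {type(self.lightning_module).__name__} on {self.root_device}")
        super().setup()  # the base class moves the model to the root device


# Hypothetical usage: the plugin is passed to the Trainer exactly as before.
trainer = pl.Trainer(plugins=[MyPlugin(device=torch.device("cpu"))])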
43 changes: 32 additions & 11 deletions docs/source/advanced/advanced_gpu.rst
@@ -101,15 +101,13 @@ To reach larger parameter sizes and be memory efficient, we have to shard parameters.
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.

Wrap the Model
""""""""""""""
Enabling Module Sharding for Maximum Memory Efficiency
""""""""""""""""""""""""""""""""""""""""""""""""""""""

To activate parameter sharding, you must wrap your model using the provided ``wrap`` or ``auto_wrap`` functions as described below. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly.

When not using Fully Sharded Training, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove them for other plugins.

This is a requirement for really large models and also saves on instantiation time as modules are sharded instantly, rather than after the entire model is created in memory.

``auto_wrap`` will recursively wrap `torch.nn.Modules` within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information `here <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html>`__).

@@ -129,22 +127,28 @@ Below is an example of using both ``wrap`` and ``auto_wrap`` to create your model
class MyModel(pl.LightningModule):
...
def __init__(self):
super().__init__()
self.linear_layer = nn.Linear(32, 32)
self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
def configure_sharded_model(self):
# Created within sharded model context, modules are instantly sharded across processes
# as soon as they are wrapped with ``wrap`` or ``auto_wrap``
# modules are sharded across processes
# as soon as they are wrapped with ``wrap`` or ``auto_wrap``.
# During the forward/backward passes, weights get synced across processes
# and de-allocated once computation is complete, saving memory.
# Wraps the layer in a Fully Sharded Wrapper automatically
linear_layer = wrap(nn.Linear(32, 32))
linear_layer = wrap(self.linear_layer)
# Wraps the module recursively
# based on a minimum number of parameters (default 100M parameters)
block = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
block = auto_wrap(self.block)
# For best memory efficiency,
# add fairscale activation checkpointing
final_block = auto_wrap(checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU())))
# add FairScale activation checkpointing
final_block = auto_wrap(checkpoint_wrapper(self.final_block))
self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)
def configure_optimizers(self):
@@ -359,6 +363,23 @@ Also please have a look at our :ref:`deepspeed-zero-stage-3-tips` which contains
trainer.predict()
You can also use the Lightning Trainer to run predict or evaluate with DeepSpeed once the model has been trained.

.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
class MyModel(pl.LightningModule):
...
model = MyModel()
trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
trainer.test(ckpt_path="my_saved_deepspeed_checkpoint.ckpt")
Shard Model Instantly to Reduce Initialization Time/Memory
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

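To complement the ``configure_sharded_model`` example earlier in this file's diff, here is a minimal sketch, not part of the commit, of the Trainer-side configuration for the ``MyModel`` defined there. It assumes the 1.4-era ``DDPFullyShardedPlugin`` class and that FairScale is installed:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPFullyShardedPlugin

model = MyModel()
# The wrap/auto_wrap calls inside configure_sharded_model only shard parameters
# when a fully sharded plugin is active; with other plugins they are no-ops.
trainer = Trainer(gpus=4, precision=16, plugins=DDPFullyShardedPlugin())
trainer.fit(model)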
11 changes: 5 additions & 6 deletions pytorch_lightning/accelerators/accelerator.py
@@ -75,15 +75,14 @@ def setup_environment(self) -> None:
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Sets up plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance
model: the LightningModule
"""
self.setup_training_type_plugin(model)
self.setup_training_type_plugin()
if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.setup_precision_plugin()
@@ -334,9 +333,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def setup_training_type_plugin(self, model: "pl.LightningModule") -> None:
def setup_training_type_plugin(self) -> None:
"""Attaches the training type plugin to the accelerator."""
self.training_type_plugin.setup(model)
self.training_type_plugin.setup()

def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator"""
@@ -460,7 +459,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.LightningModule") -> None:
rank_zero_warn(
"Accelerator method `connect_training_type_plugin` was deprecated in v1.3. It will be removed in v1.5."
)
self.setup_training_type_plugin(model)
self.setup_training_type_plugin()

# todo: remove in v1.5
def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
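The same signature change applies to accelerator subclasses, as the CPU, GPU, and TPU diffs below show. A minimal sketch, not part of the commit, of a hypothetical user-defined accelerator wired up with the 1.4-era plugin classes:

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin


class LoggingCPUAccelerator(CPUAccelerator):
    # `setup()` now receives only the trainer; the LightningModule is
    # available through `self.lightning_module` when it is needed.
    def setup(self, trainer: "pl.Trainer") -> None:
        print(f"Accelerator setup for {type(self.lightning_module).__name__}")
        return super().setup(trainer)


accelerator = LoggingCPUAccelerator(
    precision_plugin=PrecisionPlugin(),
    training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
)
trainer = pl.Trainer(accelerator=accelerator)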
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu.py
@@ -20,7 +20,7 @@
class CPUAccelerator(Accelerator):
"""Accelerator for CPU devices."""

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
@@ -36,4 +36,4 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
if "cpu" not in str(self.root_device):
raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.")

return super().setup(trainer, model)
return super().setup(trainer)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/gpu.py
@@ -32,14 +32,14 @@ def setup_environment(self) -> None:
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
torch.cuda.set_device(self.root_device)

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
self.set_nvidia_flags(trainer.local_rank)
return super().setup(trainer, model)
return super().setup(trainer)

def on_train_start(self) -> None:
# clear cache before training
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -32,7 +32,7 @@
class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
def setup(self, trainer: "pl.Trainer") -> None:
"""
Raises:
MisconfigurationException:
@@ -45,7 +45,7 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)
return super().setup(trainer)

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -166,11 +166,10 @@ def get_max_batches(self) -> List[Union[int, float]]:

def reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary"""
model = self.trainer.lightning_module
if self.trainer.testing:
self.trainer.reset_test_dataloader(model)
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader(model)
self.trainer.reset_val_dataloader()

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks"""
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -29,8 +29,7 @@ def global_rank(self) -> int:
def world_size(self) -> int:
return self.num_nodes

def setup(self, model):
self._model = model
def setup(self) -> None:
# set the task idx
self.task_idx = self.cluster_environment.local_rank()
# the difference to DDP is that we don't call children processes here
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -136,7 +136,7 @@ def distributed_sampler_kwargs(self):
def _is_single_process_single_device(self):
return True

def setup(self, model):
def setup(self) -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
# pass in a state q
smp = mp.get_context("spawn")
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -48,10 +48,10 @@ def node_rank(self) -> int:
def world_size(self) -> int:
return 1

def setup(self, model):
def setup(self) -> None:
# model needs to be moved to the device before it is wrapped
model.to(self.root_device)
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
self.model_to_device()
self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)

def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
@@ -76,9 +76,8 @@ def mean(t: torch.Tensor) -> torch.Tensor:
def root_device(self):
return self.parallel_devices[0]

def model_to_device(self):
# no need to do anything when model is wrapped in torch.nn.DataParallel
pass
def model_to_device(self) -> None:
self._model.to(self.root_device)

def barrier(self, *args, **kwargs):
pass
9 changes: 5 additions & 4 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -16,7 +16,6 @@

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -138,9 +137,11 @@ def wrap_policy(*args, **kwargs):
):
yield

def connect(self, model: Module) -> None:
super().connect(model)
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
def setup_environment(self) -> None:
super().setup_environment()
model_call_configure_sharded_model_hook = getattr(
self.lightning_module, "call_configure_sharded_model_hook", False
)
if not model_call_configure_sharded_model_hook:
# if model has not called configure sharded model, we reset
# the training type plugin's call_configure_sharded_model_hook
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -58,8 +58,7 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model
def setup(self) -> None:
self.model_to_device()

def pre_dispatch(self):
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -63,9 +63,8 @@ def root_device(self) -> torch.device:
def model_to_device(self) -> None:
self._model.to(self.root_device)

def setup(self, model: torch.nn.Module) -> torch.nn.Module:
def setup(self) -> None:
self.model_to_device()
return self.model

@property
def is_global_zero(self) -> bool:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -111,9 +111,8 @@ def pre_dispatch(self):
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)

def setup(self, model: Module) -> Module:
def setup(self) -> None:
self.create_mp_queue()
return self.model

def create_mp_queue(self):
self.start_method = "fork"
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -54,7 +54,7 @@ def setup_environment(self) -> None:
which allows the user to access the accelerator environment before setup is complete.
"""

def setup(self, model: Module) -> None:
def setup(self) -> None:
"""Called by the accelerator to finish setup."""

@property
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/callback_hook.py
@@ -32,20 +32,20 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
lightning_module: "pl.LightningModule"

def on_before_accelerator_backend_setup(self, model: "pl.LightningModule") -> None:
def on_before_accelerator_backend_setup(self) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_before_accelerator_backend_setup(self, model)
callback.on_before_accelerator_backend_setup(self, self.lightning_module)

def configure_sharded_model(self, model: "pl.LightningModule") -> None:
def on_configure_sharded_model(self) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.on_configure_sharded_model(self, model)
callback.on_configure_sharded_model(self, self.lightning_module)

def setup(self, model: "pl.LightningModule", stage: Optional[str]) -> None:
def setup(self, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, or predict, or tune."""
for callback in self.callbacks:
callback.setup(self, model, stage=stage)
callback.setup(self, self.lightning_module, stage=stage)

def teardown(self, stage: Optional[str] = None) -> None:
"""Called at the end of fit (train + validate), validate, test, or predict, or tune."""
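The callback hooks in the diff above still hand the `LightningModule` to every callback; it is now looked up via `trainer.lightning_module` instead of being passed into the trainer-level hook. A minimal sketch, not part of the commit, of a user callback that is unaffected by the change (the `PrintSetupCallback` name is illustrative):

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class PrintSetupCallback(Callback):
    # The user-facing signature is unchanged: the module still arrives as `pl_module`.
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        print(f"setup() called for stage={stage} with {type(pl_module).__name__}")


# Hypothetical usage: register it like any other callback.
trainer = pl.Trainer(callbacks=[PrintSetupCallback()])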