Skip to content

Commit

Permalink
Lightning Lite core and tests (#10175)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Oct 29, 2021
1 parent b4f43b1 commit 9d136a9
Show file tree
Hide file tree
Showing 13 changed files with 1,398 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
* Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
* Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
* Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))


- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
Expand Down
17 changes: 17 additions & 0 deletions pytorch_lightning/lite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.lite.lite import LightningLite

__all__ = ["LightningLite"]
501 changes: 501 additions & 0 deletions pytorch_lightning/lite/lite.py

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, Iterator, Optional, Union

import torch
from torch import nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


def _do_nothing_closure() -> None:
return None


class _LiteOptimizer:
def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
"""LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
step calls to the accelerator/strategy plugin.
The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
Args:
optimizer: The optimizer to wrap
accelerator: Reference to the accelerator for handling the optimizer step
"""
# `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
# not want to call on destruction of the `_LiteOptimizer
self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
self._optimizer = optimizer
self._accelerator = accelerator

@property
def optimizer(self) -> Optimizer:
return self._optimizer

def step(self, closure: Optional[Callable] = None) -> None:
closure = closure or _do_nothing_closure
self._accelerator.optimizer_step(
self.optimizer,
opt_idx=0,
closure=closure,
model=self._accelerator.model,
)


class _LiteModule(nn.Module):
def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None:
"""The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
automatically for the forward pass.
The underlying wrapped module can be accessed via the property :attr:`module`.
Args:
module: The module to wrap
precision_plugin: Reference to the precision plugin for handling precision context
"""
super().__init__()
self._module = module
self._precision_plugin = precision_plugin

@property
def module(self) -> nn.Module:
return self._module

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward
method."""
precision = self._precision_plugin.precision
precision_to_type = {
"bf16": torch.bfloat16,
16: torch.float16,
32: torch.float32,
64: torch.float64,
}
# TODO (@awaelchli): let the precision plugin handle the conversion
to_type = precision_to_type[precision]
args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor)

with self._precision_plugin.forward_context():
output = self.module(*args, **kwargs)

output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor)
return output


class _LiteDataLoader(DataLoader):
def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
"""The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
additional features such as moving the data to the device automatically.
Args:
device: The device to which the data should be moved. By default the device is `None` and no data
transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
**dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
"""
super().__init__(**dl_kwargs)
self._device = device

@property
def device(self) -> Optional[torch.device]:
return self._device

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
iterator = super().__iter__()
if self._device is None:
return iterator

for item in iterator:
yield move_data_to_device(item, self._device)
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def precision(self) -> Union[str, int]:

@property
def amp_level(self) -> Optional[str]:
return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level
if self._amp_type == AMPType.APEX:
return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level

@property
def amp_type(self) -> Optional[str]:
Expand Down
11 changes: 7 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
" in the `DataLoader` init to improve performance."
)

def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
@staticmethod
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)

def _requires_distributed_sampler(self, dataloader) -> bool:
return (
Expand Down Expand Up @@ -336,7 +337,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
Expand Down Expand Up @@ -443,7 +444,9 @@ def _reset_eval_dataloader(
dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
apply_to_collection(
dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank
)

loader_num_batches = []

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __init__(
if gradient_clip_algorithm is not None
else gradient_clip_algorithm
)
self.track_grad_norm = float(track_grad_norm)
self.track_grad_norm: float = float(track_grad_norm)

self._detect_anomaly: bool = detect_anomaly
self._setup_on_init(num_sanity_val_steps)
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run_model_test(
assert trainer.state.finished, f"Training failed with {trainer.state}"
# Check that the model is actually changed post-training
change_ratio = torch.norm(initial_values - post_train_values)
assert change_ratio > 0.1, f"the model is changed of {change_ratio}"
assert change_ratio > 0.03, f"the model is changed of {change_ratio}"

# test model loading
pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
Expand Down
Empty file added tests/lite/__init__.py
Empty file.
Loading

0 comments on commit 9d136a9

Please sign in to comment.