diff --git a/CHANGELOG.md b/CHANGELOG.md index a0dbf92661ffb..642e8dd25c436 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -219,6 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023)) * Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164)) * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162)) + * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175)) - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972)) diff --git a/pytorch_lightning/lite/__init__.py b/pytorch_lightning/lite/__init__.py new file mode 100644 index 0000000000000..f4634fe54e548 --- /dev/null +++ b/pytorch_lightning/lite/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.lite.lite import LightningLite + +__all__ = ["LightningLite"] diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py new file mode 100644 index 0000000000000..7d0ff6a436b61 --- /dev/null +++ b/pytorch_lightning/lite/lite.py @@ -0,0 +1,501 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
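Below is a minimal usage sketch of the `pytorch_lightning.lite` package this patch introduces. The `MyLite` subclass, the model, and the random data are illustrative placeholders; only `LightningLite` and its `run`/`setup`/`setup_dataloaders`/`backward` methods come from the API added in this diff.

    import torch
    from torch.utils.data import DataLoader

    from pytorch_lightning.lite import LightningLite


    class MyLite(LightningLite):
        def run(self):
            # any nn.Module, optimizer, and DataLoader work here; these are placeholders
            model = torch.nn.Linear(32, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            # wrapping handles device placement and precision for model, optimizer, and data
            model, optimizer = self.setup(model, optimizer)
            dataloader = self.setup_dataloaders(DataLoader(torch.randn(64, 32)))

            model.train()
            for batch in dataloader:
                optimizer.zero_grad()
                loss = model(batch).sum()
                self.backward(loss)  # replaces loss.backward()
                optimizer.step()


    # accelerator/strategy/devices/precision mirror the constructor arguments documented below
    MyLite(accelerator="cpu", devices=1).run()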
+import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler + +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.plugins import ( + DDPShardedPlugin, + DDPSpawnPlugin, + DeepSpeedPlugin, + PLUGIN_INPUT, + TPUSpawnPlugin, + TrainingTypePlugin, +) +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors +from pytorch_lightning.utilities.data import has_iterable_dataset +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything + + +class LightningLite(ABC): + """Lite accelerates your PyTorch training or inference code with minimal changes required. + + - Automatic placement of models and data onto the device. + - Automatic support for mixed and double precision (smaller memory footprint). + - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies + (data-parallel training, sharded training, etc.). + - Automated spawning of processes, no launch utilities required. + - Multi-node support. + + Args: + accelerator: The hardware to run on. Possible choices are: ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``. + strategy: Strategy for how to run across multiple devices. Possible choices are: + ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``. + devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``. + The value applies per node. + num_nodes: Number of GPU nodes for distributed training. + precision: Double precision (``64``), full precision (``32``), half precision (``16``), + or bfloat16 precision (``"bf16"``). + plugins: One or several custom plugins + gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``. + tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``. 
+ """ + + def __init__( + self, + accelerator: Optional[Union[str, Accelerator]] = None, + strategy: Optional[Union[str, TrainingTypePlugin]] = None, + devices: Optional[Union[List[int], str, int]] = None, + num_nodes: int = 1, + precision: Union[int, str] = 32, + plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + gpus: Optional[Union[List[int], str, int]] = None, + tpu_cores: Optional[Union[List[int], str, int]] = None, + ) -> None: + self._check_accelerator_support(accelerator) + self._check_strategy_support(strategy) + gpu_ids, tpu_cores = Trainer._parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores) + self._accelerator_connector = AcceleratorConnector( + num_processes=1, + devices=devices, + tpu_cores=tpu_cores, + ipus=None, + accelerator=accelerator, + strategy=strategy, + gpus=gpus, + gpu_ids=gpu_ids, + num_nodes=num_nodes, + sync_batchnorm=False, # TODO: add support? + benchmark=False, + replace_sampler_ddp=True, + deterministic=False, + precision=precision, + amp_type="native", + amp_level=None, + plugins=plugins, + ) + self._accelerator = self._accelerator_connector.accelerator + self._strategy = self._accelerator.training_type_plugin + self._precision_plugin = self._accelerator.precision_plugin + self._models_setup: int = 0 + + # wrap the run method so we can inject setup logic or spawn processes for the user + setattr(self, "run", partial(self._run_impl, self.run)) + + @property + def device(self) -> torch.device: + """The current device this process runs on. + + Use this to create tensors directly on the device if needed. + """ + return self._accelerator.root_device + + @property + def global_rank(self) -> int: + """The global index of the current process across all devices and nodes.""" + return getattr(self._strategy, "global_rank", 0) + + @property + def local_rank(self) -> int: + """The index of the current process among the processes running on the local node.""" + return getattr(self._strategy, "local_rank", 0) + + @property + def node_rank(self) -> int: + """The index of the current node.""" + return getattr(self._strategy, "node_rank", 0) + + @property + def world_size(self) -> int: + """The total number of processes running across all devices and nodes.""" + return getattr(self._strategy, "world_size", 1) + + @property + def is_global_zero(self) -> bool: + """Wether this rank is rank zero.""" + return self._strategy.is_global_zero + + @abstractmethod + def run(self) -> Any: + """All the code inside this run method gets accelerated by Lite. + + You can pass arbitrary arguments to this function when overriding it. + """ + + def setup( + self, + model: nn.Module, + *optimizers: Optimizer, + move_to_device: bool = True, + ) -> Any: # no specific return because the way we want our API to look does not play well with mypy + """Setup a model and its optimizers for accelerated training. + + Args: + model: A model to setup + *optimizers: The optimizer(s) to setup (no optimizers is also possible) + move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` + and alternatively use :meth:`to_device` manually. + + Returns: + The tuple of the wrapped model and list of optimizers, in the same order they were passed in. 
+        """
+        self._validate_setup(model, optimizers)
+
+        if move_to_device:
+            model = self._move_model_to_device(model=model, optimizers=list(optimizers))
+
+        # Let accelerator/plugin wrap and connect the models and optimizers
+        model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
+        model = _LiteModule(model, self._precision_plugin)
+        optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers]
+        self._models_setup += 1
+        if optimizers:
+            # join both types in a list for API convenience
+            return [model] + optimizers  # type: ignore
+        return model
+
+    def setup_dataloaders(
+        self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
+    ) -> Union[DataLoader, List[DataLoader], Iterable]:
+        """Set up one or multiple dataloaders for accelerated training. If you need different settings for each
+        dataloader, call this method individually for each one.
+
+        Args:
+            *dataloaders: A single dataloader or a sequence of dataloaders.
+            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader(s)
+                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
+            move_to_device: If set ``True`` (default), moves the data returned by the dataloader(s) automatically to
+                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
+                returned data.
+
+        Returns:
+            The wrapped dataloaders, in the same order they were passed in.
+        """
+        self._validate_setup_dataloaders(dataloaders)
+        dataloaders = [
+            self._setup_dataloader(dataloader, replace_sampler=replace_sampler, move_to_device=move_to_device)
+            for dataloader in dataloaders
+        ]
+        dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
+        return dataloaders
+
+    def _setup_dataloader(
+        self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
+    ) -> Union[Iterable, DataLoader]:
+        """Set up a single dataloader for accelerated training.
+
+        Args:
+            dataloader: The dataloader to accelerate.
+            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader
+                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
+            move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
+                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
+                returned data.
+
+        Returns:
+            The wrapped dataloader.
+        """
+        sampler = dataloader.sampler
+        if replace_sampler and self._requires_distributed_sampler(dataloader):
+            if not isinstance(sampler, (SequentialSampler, RandomSampler)):
+                raise MisconfigurationException(
+                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
+                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
+                    " distributed training. Either remove the sampler from your DataLoader or set"
+                    " `replace_sampler=False` if you want to use your custom sampler."
+                )
+            sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
+
+        kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
+        device = self.device if move_to_device else None
+        if isinstance(self._strategy, TPUSpawnPlugin):
+            dataloader = DataLoader(**kwargs)
+        else:
+            dataloader = _LiteDataLoader(device=device, **kwargs)
+
+        # add worker_init_fn for correct seeding in worker processes
+        TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
+
+        return self._strategy.process_dataloader(dataloader)
+
+    def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
+        """Replaces ``loss.backward()`` in your training loop. Handles precision automatically for you.
+
+        Args:
+            tensor: The tensor (loss) to back-propagate gradients from.
+            *args: Optional positional arguments passed to the underlying backward function.
+            model: Optional model instance for plugins that require the model for backward().
+            **kwargs: Optional named keyword arguments passed to the underlying backward function.
+
+        Note:
+            When using ``strategy="deepspeed"`` and multiple models were set up, it is required to pass in the
+            model as argument here.
+        """
+        module = model.module if model is not None else model
+        if isinstance(self._strategy, DeepSpeedPlugin):
+            if model is None:
+                if self._models_setup == 0:
+                    raise MisconfigurationException(
+                        "No models were setup for backward. Did you forget to call `self.setup()`?"
+                    )
+                if self._models_setup > 1:
+                    raise MisconfigurationException(
+                        "When using multiple models + deepspeed, please provide the model used to perform"
+                        " the optimization: `self.backward(loss, model=model)`"
+                    )
+                module = self._strategy.model
+            else:
+                # we need to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call.
+                self._strategy.model = module
+
+        self._precision_plugin._run_backward(tensor, module, *args, **kwargs)
+
+    @contextmanager
+    def autocast(self) -> Generator[None, None, None]:
+        """A context manager to automatically convert operations for the chosen precision.
+
+        Use this only if the `forward` method of your model does not cover all operations you wish to run with the
+        chosen precision setting.
+        """
+        with self._precision_plugin.forward_context():
+            yield
+
+    def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
+        """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
+        on that device.
+
+        Args:
+            obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a
+                (nested) collection of tensors (e.g., a dictionary).
+
+        Returns:
+            A reference to the object that was moved to the new device.
+        """
+        if isinstance(obj, nn.Module):
+            if self.device.type == "cuda":
+                # need to call this manually here again in case we spawned with DDPSpawnPlugin
+                # TODO: refactor to let plugin handle this cleanly
+                torch.cuda.set_device(self.device)
+            return obj.to(self.device)
+        return move_data_to_device(obj, device=self.device)
+
+    def print(self, *args: Any, **kwargs: Any) -> None:
+        """Print something only on the first process.
+
+        Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
+        """
+        if self.local_rank == 0:
+            print(*args, **kwargs)
+
+    def barrier(self, name: Optional[str] = None) -> None:
+        """Wait for all processes to enter this call.
Use this to synchronize all parallel processes, but only if + necessary, otherwise the overhead of synchronization will cause your program to slow down. + + Example:: + + if self.global_rank == 0: + # let process 0 download the dataset + dataset.download_files() + + # let all processes wait before reading the dataset + self.barrier() + + # now all processes can read the files and start training + """ + self._strategy.barrier(name=name) + + def all_gather( + self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[torch.Tensor, Dict, List, Tuple]: + r""" + Gather tensors or collections of tensors from multiple processes. + + Args: + data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for the all_gather operation + + Return: + A tensor of shape (world_size, batch, ...), or if the input was a collection + the output will also be a collection with tensors of this shape. + """ + group = group if group is not None else torch.distributed.group.WORLD + data = convert_to_tensors(data, device=self.device) + return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads) + + def broadcast(self, obj: object, src: int = 0) -> object: + return self._strategy.broadcast(obj, src=src) + + def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None: + """Save checkpoint contents to a file. + + How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy + saves checkpoints only on process 0. + + Args: + content: A dictionary with contents, i.e., the state dict of your model + filepath: A path to where the file should be saved + """ + self._strategy.save_checkpoint(content, filepath) + + def load(self, filepath: Union[str, Path]) -> Any: + """Load a checkpoint from a file. + + How and which processes load gets determined by the `strategy` + + Args: + filepath: A path to where the file is located + """ + return self._strategy.load_checkpoint(filepath) + + @staticmethod + def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int: + """Helper function to seed everything without explicitly importing Lightning. + + See :func:`pytorch_lightning.seed_everything` for more details. + """ + if workers is None: + # Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new + # release, we can afford to do it. 
+ workers = True + return seed_everything(seed=seed, workers=workers) + + def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + self._set_plugin_specific_precision_variables() + self._accelerator.setup_environment() + + # apply sharded context to prevent OOM + run_method = partial(self._run_with_sharded_context, run_method) + + if isinstance(self._strategy, DDPSpawnPlugin): + return self._strategy.spawn(run_method, *args, return_result=True, **kwargs) + else: + return run_method(*args, **kwargs) + + def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + with self._strategy.model_sharded_context(): + return run_method(*args, **kwargs) + + def _set_plugin_specific_precision_variables(self) -> None: + # todo: these are hacks as plugins rely on access to the precision plugin + if isinstance(self._strategy, DeepSpeedPlugin): + self._set_deepspeed_precision_variables() + if isinstance(self._strategy, DDPShardedPlugin): + self._strategy._precision = self._accelerator_connector.precision + + def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: + if isinstance(self._strategy, TPUSpawnPlugin): + # When the user creates the optimizer, they reference the parameters on the CPU. + # However, when running with TPU the parameters get copied and the reference in the optimizer + # remains invalid. We need to update the references to point to the parameter tensors on the device. + params_before_move = dict(model.named_parameters()) + model = self.to_device(model) + # XLA makes a copy on the parameters, so the device is not the same before and after to_device. + params_on_device = dict(model.named_parameters()) + + mapping = {param: params_on_device[name] for name, param in params_before_move.items()} + for optimizer in optimizers: + for param_group in optimizer.param_groups: + param_group["params"] = [mapping.get(p, p) for p in param_group["params"]] + else: + model = self.to_device(model) + return model + + def _set_deepspeed_precision_variables(self) -> None: + # TODO: Refactor this once precision pluging is part of the strategy. + amp_type = self._accelerator_connector.amp_type + amp_level = self._accelerator_connector.amp_level + precision = self._accelerator_connector.precision + self._strategy._amp_level, self._strategy._amp_type, self._strategy._precision = amp_level, amp_type, precision + + def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: + return ( + self._accelerator_connector.is_distributed + and not isinstance(dataloader.sampler, DistributedSampler) + and not has_iterable_dataset(dataloader) + ) + + @staticmethod + def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: + kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + return DistributedSampler(dataloader.dataset, **kwargs) + + def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None: + supported = [t.value.lower() for t in self._supported_device_types()] + ["auto"] + valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported + if not valid: + raise MisconfigurationException( + f"`accelerator={repr(accelerator)}` is not a valid choice." + f" Choose one of {supported} or pass in a `Accelerator` instance." 
+ ) + + def _check_strategy_support(self, strategy: Optional[Union[str, TrainingTypePlugin]]) -> None: + supported = [t.lower() for t in self._supported_strategy_types()] + valid = strategy is None or isinstance(strategy, TrainingTypePlugin) or strategy in supported + if not valid: + raise MisconfigurationException( + f"`strategy={repr(strategy)}` is not a valid choice." + f" Choose one of {supported} or pass in a `TrainingTypePlugin` instance." + ) + + @staticmethod + def _supported_device_types() -> Sequence[DeviceType]: + return ( + DeviceType.CPU, + DeviceType.GPU, + DeviceType.TPU, + ) + + @staticmethod + def _supported_strategy_types() -> Sequence[DistributedType]: + return ( + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.DEEPSPEED, + DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, + ) + + @staticmethod + def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: + if isinstance(model, _LiteModule): + raise MisconfigurationException("A model should be passed only once to the `setup` method.") + + if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): + raise MisconfigurationException("An optimizer should be passed only once to the `setup` method.") + + @staticmethod + def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: + if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): + raise MisconfigurationException("A dataloader should be passed only once to the `setup_dataloaders` method") + + if any(not isinstance(dl, DataLoader) for dl in dataloaders): + raise MisconfigurationException("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py new file mode 100644 index 0000000000000..3dd387319ae68 --- /dev/null +++ b/pytorch_lightning/lite/wrappers.py @@ -0,0 +1,126 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Generator, Iterator, Optional, Union + +import torch +from torch import nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.plugins import PrecisionPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device + + +def _do_nothing_closure() -> None: + return None + + +class _LiteOptimizer: + def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: + """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer + step calls to the accelerator/strategy plugin. + + The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`. 
+ + Args: + optimizer: The optimizer to wrap + accelerator: Reference to the accelerator for handling the optimizer step + """ + # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would + # not want to call on destruction of the `_LiteOptimizer + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) + self._optimizer = optimizer + self._accelerator = accelerator + + @property + def optimizer(self) -> Optimizer: + return self._optimizer + + def step(self, closure: Optional[Callable] = None) -> None: + closure = closure or _do_nothing_closure + self._accelerator.optimizer_step( + self.optimizer, + opt_idx=0, + closure=closure, + model=self._accelerator.model, + ) + + +class _LiteModule(nn.Module): + def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: + """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast + automatically for the forward pass. + + The underlying wrapped module can be accessed via the property :attr:`module`. + + Args: + module: The module to wrap + precision_plugin: Reference to the precision plugin for handling precision context + """ + super().__init__() + self._module = module + self._precision_plugin = precision_plugin + + @property + def module(self) -> nn.Module: + return self._module + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Casts all inputs to the right precision and handles autocast for operations in the module forward + method.""" + precision = self._precision_plugin.precision + precision_to_type = { + "bf16": torch.bfloat16, + 16: torch.float16, + 32: torch.float32, + 64: torch.float64, + } + # TODO (@awaelchli): let the precision plugin handle the conversion + to_type = precision_to_type[precision] + args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor) + + with self._precision_plugin.forward_context(): + output = self.module(*args, **kwargs) + + output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor) + return output + + +class _LiteDataLoader(DataLoader): + def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None: + """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds + additional features such as moving the data to the device automatically. + + Args: + device: The device to which the data should be moved. By default the device is `None` and no data + transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). + **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts. 
+ """ + super().__init__(**dl_kwargs) + self._device = device + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + iterator = super().__iter__() + if self._device is None: + return iterator + + for item in iterator: + yield move_data_to_device(item, self._device) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 955f8973d3c18..1325c2c380134 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -336,7 +336,8 @@ def precision(self) -> Union[str, int]: @property def amp_level(self) -> Optional[str]: - return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level + if self._amp_type == AMPType.APEX: + return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level @property def amp_type(self) -> Optional[str]: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 24206b8af1fc1..071eead5613b4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -114,9 +114,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: " in the `DataLoader` init to improve performance." ) - def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: + @staticmethod + def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: - dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) + dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) def _requires_distributed_sampler(self, dataloader) -> bool: return ( @@ -336,7 +337,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader") # add worker_init_fn for correct seeding in worker processes - apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn) + apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank) # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_training(): @@ -443,7 +444,9 @@ def _reset_eval_dataloader( dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None] # add worker_init_fn for correct seeding in worker processes - apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn) + apply_to_collection( + dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank + ) loader_num_batches = [] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 373608c3c8b28..74522424c5326 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -559,7 +559,7 @@ def __init__( if gradient_clip_algorithm is not None else gradient_clip_algorithm ) - self.track_grad_norm = float(track_grad_norm) + self.track_grad_norm: float = float(track_grad_norm) self._detect_anomaly: bool = detect_anomaly self._setup_on_init(num_sanity_val_steps) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 3e5066d708da0..643d3e50cb894 100644 --- a/tests/helpers/pipelines.py +++ 
b/tests/helpers/pipelines.py @@ -67,7 +67,7 @@ def run_model_test( assert trainer.state.finished, f"Training failed with {trainer.state}" # Check that the model is actually changed post-training change_ratio = torch.norm(initial_values - post_train_values) - assert change_ratio > 0.1, f"the model is changed of {change_ratio}" + assert change_ratio > 0.03, f"the model is changed of {change_ratio}" # test model loading pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) diff --git a/tests/lite/__init__.py b/tests/lite/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py new file mode 100644 index 0000000000000..916e0aa542b32 --- /dev/null +++ b/tests/lite/test_lite.py @@ -0,0 +1,413 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from copy import deepcopy +from unittest import mock +from unittest.mock import MagicMock, Mock, PropertyMock + +import pytest +import torch +import torch.distributed +import torch.nn.functional +from torch import nn +from torch.utils.data import DataLoader, DistributedSampler, Sampler + +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import pl_worker_init_function +from tests.helpers.runif import RunIf + + +class EmptyLite(LightningLite): + def run(self): + pass + + +class BoringModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + +def test_unsupported_accelerator(): + accelerator = "coconut" + with pytest.raises(MisconfigurationException, match=f"`accelerator={repr(accelerator)}` is not a valid choice"): + EmptyLite(accelerator=accelerator) + + +def test_unsupported_strategy(): + strategy = "coconut" + with pytest.raises(MisconfigurationException, match=f"`strategy={repr(strategy)}` is not a valid choice"): + EmptyLite(strategy=strategy) + + +def test_run_input_output(): + """Test that the dynamically patched run() method receives the input arguments and returns the result.""" + + class Lite(LightningLite): + + run_args = () + run_kwargs = {} + + def run(self, *args, **kwargs): + self.run_args = args + self.run_kwargs = kwargs + return "result" + + lite = Lite() + result = lite.run(1, 2, three=3) + assert result == "result" + assert lite.run_args == (1, 2) + assert lite.run_kwargs == {"three": 3} + + +def test_setup_optimizers(): + """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers.""" + lite = EmptyLite() + model = 
nn.Linear(1, 2) + optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1) + + # no optimizer + lite_model = lite.setup(model) + assert isinstance(lite_model, _LiteModule) + assert lite_model.module is model + + # single optimizer + lite_model, lite_optimizer = lite.setup(model, optimizer0) + assert isinstance(lite_model, _LiteModule) + assert isinstance(lite_optimizer, _LiteOptimizer) + assert lite_model.module is model + assert lite_optimizer.optimizer is optimizer0 + + # multiple optimizers + lite_model, lite_optimizer0, lite_optimizer1 = lite.setup(model, optimizer0, optimizer1) + assert isinstance(lite_model, _LiteModule) + assert isinstance(lite_optimizer0, _LiteOptimizer) + assert isinstance(lite_optimizer1, _LiteOptimizer) + assert lite_model.module is model + assert lite_optimizer0.optimizer is optimizer0 + assert lite_optimizer1.optimizer is optimizer1 + + +def test_setup_twice_fails(): + """Test that calling setup with a model or optimizer that is already wrapped fails.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + lite_model, lite_optimizer = lite.setup(model, optimizer) + with pytest.raises(MisconfigurationException, match="A model should be passed only once to the"): + lite.setup(lite_model, optimizer) + + lite_model, lite_optimizer = lite.setup(model, optimizer) + with pytest.raises(MisconfigurationException, match="An optimizer should be passed only once to the"): + lite.setup(model, lite_optimizer) + + +def test_setup_tracks_num_models(): + """Test that setup() tracks how many times it has setup a model.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + assert lite._models_setup == 0 + lite.setup(model, optimizer) + assert lite._models_setup == 1 + + lite.setup(model, optimizer) + assert lite._models_setup == 2 + + +def test_setup_dataloaders_unsupported_type(): + """Test that the setup_dataloaders method fails when provided with non-DataLoader objects.""" + lite = EmptyLite() + with pytest.raises(MisconfigurationException, match="Only PyTorch DataLoader are currently supported"): + lite.setup_dataloaders(range(2)) # type: ignore + + +def test_setup_dataloaders_return_type(): + """Test that the setup method returns the dataloaders wrapped as LiteDataLoader and in the right order.""" + lite = EmptyLite() + + # single dataloader + lite_dataloader = lite.setup_dataloaders(DataLoader(range(2))) + assert isinstance(lite_dataloader, _LiteDataLoader) + + # multiple dataloaders + dataset0 = Mock() + dataset1 = Mock() + dataloader0 = DataLoader(dataset0) + dataloader1 = DataLoader(dataset1) + lite_dataloader0, lite_dataloader1 = lite.setup_dataloaders(dataloader0, dataloader1) + assert isinstance(lite_dataloader0, _LiteDataLoader) + assert isinstance(lite_dataloader1, _LiteDataLoader) + assert lite_dataloader0.dataset is dataset0 + assert lite_dataloader1.dataset is dataset1 + + +def test_setup_dataloaders_twice_fails(): + """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" + lite = EmptyLite() + dataloader = DataLoader(range(2)) + lite_dataloader = lite.setup_dataloaders(dataloader) + + with pytest.raises(MisconfigurationException, match="A dataloader should be passed only once to the"): + lite.setup_dataloaders(lite_dataloader) + + +@mock.patch( + "pytorch_lightning.lite.lite.LightningLite.device", + new_callable=PropertyMock, + return_value=torch.device("cuda", 1), +) 
+def test_setup_dataloaders_move_to_device(lite_device_mock): + """Test that the setup configures LiteDataLoader to move the data to the device automatically.""" + lite = EmptyLite() + lite_dataloaders = lite.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False) + assert all(dl.device is None for dl in lite_dataloaders) + lite_device_mock.assert_not_called() + + lite = EmptyLite() + lite_dataloaders = lite.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True) + assert all(dl.device == torch.device("cuda", 1) for dl in lite_dataloaders) + lite_device_mock.assert_called() + + +def test_setup_dataloaders_distributed_sampler_not_needed(): + """Test that replace_sampler option has no effect when no distributed sampler is needed.""" + custom_sampler = Mock(spec=Sampler) + dataloader = DataLoader(Mock(), sampler=custom_sampler) + + # keep the custom sampler when not needed to replace + lite = EmptyLite() + lite_dataloader = lite.setup_dataloaders(dataloader, replace_sampler=True) + assert lite_dataloader.sampler is custom_sampler + + +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_everything(): + """Test that seed everything is static and sets the worker init function on the dataloader.""" + EmptyLite.seed_everything(3) + + lite = EmptyLite() + lite_dataloader = lite.setup_dataloaders(DataLoader(Mock())) + + assert lite_dataloader.worker_init_fn.func is pl_worker_init_function + assert os.environ == {"PL_GLOBAL_SEED": "3", "PL_SEED_WORKERS": "1"} + + +@pytest.mark.parametrize( + "strategy", + [ + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + ], +) +def test_setup_dataloaders_replace_custom_sampler(strategy): + """Test that asking to replace a custom sampler results in an error when a distributed sampler would be + needed.""" + custom_sampler = Mock(spec=Sampler) + dataloader = DataLoader(Mock(), sampler=custom_sampler) + + # explicitly asking to replace when a custom sampler is already configured raises an exception + lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) + if lite._accelerator_connector.is_distributed: + with pytest.raises(MisconfigurationException, match="You seem to have configured a sampler in your DataLoader"): + lite.setup_dataloaders(dataloader, replace_sampler=True) + + # setting `replace_sampler=False` leaves the sampler untouched + lite_dataloader = lite.setup_dataloaders(dataloader, replace_sampler=False) + assert lite_dataloader.sampler is custom_sampler + + +@pytest.mark.parametrize( + "strategy", + [ + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + ], +) +@pytest.mark.parametrize("shuffle", [True, False]) +def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy): + """Test that Lite replaces the default samplers with DistributedSampler automatically.""" + lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) + is_distributed = lite._accelerator_connector.is_distributed + lite_dataloader = lite.setup_dataloaders(DataLoader(range(3), shuffle=shuffle)) + assert not 
is_distributed or isinstance(lite_dataloader.sampler, DistributedSampler) + + +@pytest.mark.parametrize( + "accelerator, expected", + [ + ("cpu", torch.device("cpu")), + pytest.param("gpu", torch.device("cuda", 0), marks=RunIf(min_gpus=1)), + pytest.param("tpu", torch.device("xla", 0), marks=RunIf(tpu=True)), + ], +) +def test_to_device(accelerator, expected): + """Test that the to_device method can move various objects to the device determined by the accelerator.""" + lite = EmptyLite(accelerator=accelerator, devices=1) + + # module + module = torch.nn.Linear(2, 3) + module = lite.to_device(module) + assert all(param.device == expected for param in module.parameters()) + + # tensor + tensor = torch.rand(2, 2) + tensor = lite.to_device(tensor) + assert tensor.device == expected + + # collection + collection = {"data": torch.rand(2, 2), "int": 1} + collection = lite.to_device(collection) + assert collection["data"].device == expected + + +def test_rank_properties(): + """Test that the rank properties are determined by the strategy.""" + lite = EmptyLite() + lite._strategy = Mock(spec=TrainingTypePlugin) + lite._strategy.world_size = 1000 + assert lite.world_size == 1000 + lite._strategy.global_rank = 100 + assert lite.global_rank == 100 + lite._strategy.local_rank = 10 + assert lite.local_rank == 10 + lite._strategy.node_rank = 1 + assert lite.node_rank == 1 + + +def test_backward(): + """Test that backward() calls into the precision plugin.""" + lite = EmptyLite() + lite._precision_plugin = Mock(spec=PrecisionPlugin) + loss = Mock() + lite.backward(loss, "arg", keyword="kwarg") + lite._precision_plugin._run_backward.assert_called_with(loss, None, "arg", keyword="kwarg") + + +@RunIf(deepspeed=True) +def test_backward_model_input_required(): + """Test that when using deepspeed and multiple models, backward() requires the model as input.""" + lite = EmptyLite(strategy="deepspeed") + + model0 = nn.Linear(1, 2) + model1 = nn.Linear(1, 2) + + optimizer0 = torch.optim.Adam(model0.parameters()) + optimizer1 = torch.optim.Adam(model1.parameters()) + + lite._strategy._setup_model_and_optimizer = lambda *args: args + + lite.setup(model0, optimizer0) + lite.setup(model1, optimizer1) + + loss = model0(torch.randn(1, 1)).sum() + + with pytest.raises(MisconfigurationException, match="please provide the model used to perform"): + lite.backward(loss) + + +def test_autocast(): + """Test that the Lite autocast context manager lets the precision plugin handle casting.""" + lite = EmptyLite() + lite._precision_plugin.forward_context = MagicMock() + + lite._precision_plugin.forward_context().__enter__.assert_not_called() + with lite.autocast(): + lite._precision_plugin.forward_context().__enter__.assert_called() + lite._precision_plugin.forward_context().__exit__.assert_called() + + +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multiple_models(): + class Lite(LightningLite): + def run(self): + model = BoringModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + model, optimizer = self.setup(model, optimizer) + state_dict = deepcopy(model.state_dict()) + + for _ in range(2): + optimizer.zero_grad() + x = model(torch.randn(1, 32).to(self.device)) + loss = x.sum() + self.backward(loss, model=model) + optimizer.step() + + for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): + assert not torch.equal(mw_b, mw_a) + + self.seed_everything(42) + model_1 = BoringModel() + optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001) + + self.seed_everything(42) 
+ model_2 = BoringModel() + optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert torch.equal(mw_1, mw_2) + + model_1, optimizer_1 = self.setup(model_1, optimizer_1) + model_2, optimizer_2 = self.setup(model_2, optimizer_2) + + self.seed_everything(42) + data_list = [] + for _ in range(2): + optimizer_1.zero_grad() + data = torch.randn(1, 32).to(self.device) + data_list.append(data) + x = model_1(data) + loss = x.sum() + self.backward(loss, model=model_1) + optimizer_1.step() + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert not torch.equal(mw_1, mw_2) + + for data in data_list: + optimizer_2.zero_grad() + x = model_2(data) + loss = x.sum() + self.backward(loss, model=model_2) + optimizer_2.step() + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert torch.equal(mw_1, mw_2) + + # Verify collectives works as expected + ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device)) + assert torch.equal(ranks.cpu(), torch.tensor([[0], [1]])) + assert self.broadcast(True) + assert self.is_global_zero == (self.local_rank == 0) + + Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py new file mode 100644 index 0000000000000..bec9339ec8e2f --- /dev/null +++ b/tests/lite/test_parity.py @@ -0,0 +1,222 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
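The `_LiteOptimizer` wrapper in `pytorch_lightning/lite/wrappers.py` above re-parents its class at runtime so the wrapped object still passes `isinstance` checks against the original optimizer class (this is what `test_lite_optimizer_wraps` in `tests/lite/test_wrappers.py` below asserts). A stripped-down sketch of that pattern, with a placeholder `step` instead of the accelerator call:

    import torch


    class _Wrapper:
        def __init__(self, optimizer):
            # copy the optimizer state, then build a dynamic subclass of both classes
            self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
            self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
            self._optimizer = optimizer

        def step(self, closure=None):
            # the real _LiteOptimizer routes this call through accelerator.optimizer_step(...)
            return self._optimizer.step(closure)


    optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=0.1)
    wrapped = _Wrapper(optimizer)
    assert isinstance(wrapped, torch.optim.SGD)  # passes thanks to the dynamic subclass
    wrapped.step()  # intercepted by _Wrapper.step, then delegated to the wrapped SGD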
+import os +from contextlib import contextmanager +from copy import deepcopy +from functools import partial +from typing import Callable, Generator + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +import torch.nn.functional +from torch import nn +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device +from pytorch_lightning.utilities.cloud_io import atomic_save +from tests.helpers.boring_model import RandomDataset +from tests.helpers.runif import RunIf + + +class BoringModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + +def configure_optimizers(module: nn.Module): + return torch.optim.SGD(module.parameters(), lr=0.0001) + + +def main( + move_to_device: Callable, + model: nn.Module, + train_dataloader: DataLoader, + num_epochs: int = 10, +): + model = move_to_device(model) + optimizer = configure_optimizers(model) + + for _ in range(num_epochs): + model.train() + for batch in train_dataloader: + batch = move_to_device(batch) + optimizer.zero_grad() + loss = model(batch) + loss.backward() + optimizer.step() + + return model.state_dict() + + +class LiteRunner(LightningLite): + def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None): + optimizer = configure_optimizers(model) + model, optimizer = self.setup(model, optimizer) + train_dataloader = self.setup_dataloaders(train_dataloader) + + model.train() + for _ in range(num_epochs): + for batch in train_dataloader: + batch = self.to_device(batch) + optimizer.zero_grad() + loss = model(batch) + self.backward(loss) + optimizer.step() + + if isinstance(self._strategy, DDPSpawnPlugin) and tmpdir and self.global_rank == 0: + checkpoint_path = os.path.join(tmpdir, "model.pt") + atomic_save(model.state_dict(), checkpoint_path) + return checkpoint_path + + +@contextmanager +def precision_context(precision, accelerator) -> Generator[None, None, None]: + if precision == 32: + yield + return + if accelerator == "gpu": + with torch.cuda.amp.autocast(): + yield + elif accelerator == "cpu": + with torch.cpu.amp.autocast(): + yield + + +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + pytest.param(32, None, 1, "cpu"), + pytest.param(32, None, 1, "gpu", marks=RunIf(min_gpus=1)), + pytest.param(16, None, 1, "gpu", marks=RunIf(min_gpus=1)), + pytest.param("bf16", None, 1, "gpu", marks=RunIf(min_torch="1.10", min_gpus=1)), + ], +) +def test_boring_lite_model_single_device(precision, strategy, devices, accelerator, tmpdir): + LightningLite.seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 8)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + lite.run(model, train_dataloader, num_epochs=num_epochs) + lite_state_dict = model.state_dict() + + with precision_context(precision, accelerator): + 
model.load_state_dict(state_dict) + pure_state_dict = main(lite.to_device, model, train_dataloader, num_epochs=num_epochs) + + state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device) + for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()): + assert not torch.equal(w_pure, w_lite) + + for w_pure, w_lite in zip(pure_state_dict.values(), lite_state_dict.values()): + assert torch.equal(w_pure, w_lite) + + +def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir): + os.environ["LOCAL_RANK"] = str(rank) + if torch.distributed.is_available() and not torch.distributed.is_initialized(): + torch.distributed.init_process_group("gloo", rank=rank, world_size=2) + + to_device = partial(move_data_to_device, device=torch.device("cuda", rank)) + model = DistributedDataParallel( + to_device(model), + device_ids=[rank], + ) + train_dataloader = DataLoader( + train_dataloader.dataset, + sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False), + ) + with precision_context(precision, accelerator): + main(to_device, model, train_dataloader, num_epochs=num_epochs) + + if rank == 0: + atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt")) + + +@pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.") +@RunIf(min_gpus=2) +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + (32, "ddp_spawn", 2, "gpu"), + ], +) +def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator, tmpdir): + LightningLite.seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 8)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + checkpoint_path = lite.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir) + spawn_model_state_dict = torch.load(checkpoint_path) + + for w_pure, w_lite in zip(state_dict.values(), spawn_model_state_dict.values()): + assert not torch.equal(w_pure.cpu(), w_lite.cpu()) + + model.load_state_dict(state_dict) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_network_port()) + mp.spawn(run, args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir), nprocs=2) + spawn_pure_model_state_dict = torch.load(os.path.join(tmpdir, "model_spawn.pt")) + + for w_pure, w_lite in zip(spawn_pure_model_state_dict.values(), spawn_model_state_dict.values()): + assert torch.equal(w_pure.cpu(), w_lite.cpu()) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + (32, "ddp", 2, "gpu"), + ], +) +def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir): + LightningLite.seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 4)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + lite.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir) + + lite_model_state_dict = model.state_dict() + + for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()): + assert not torch.equal(w_pure.cpu(), w_lite.cpu()) + + LightningLite.seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 4)) + model = BoringModel() + run(lite.global_rank, model, train_dataloader, num_epochs, 
precision, accelerator, tmpdir) + pure_model_state_dict = model.state_dict() + + for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()): + assert torch.equal(w_pure.cpu(), w_lite.cpu()) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py new file mode 100644 index 0000000000000..3e2e9ac7a9f9a --- /dev/null +++ b/tests/lite/test_wrappers.py @@ -0,0 +1,103 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import ANY, Mock + +import pytest +import torch + +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from tests.helpers.runif import RunIf + + +class EmptyLite(LightningLite): + def run(self): + pass + + +def test_lite_module_wraps(): + """Test that the wrapped module is accessible via the property.""" + module = Mock() + assert _LiteModule(module, Mock()).module is module + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize( + "precision, input_type, expected_type", + [ + (32, torch.float16, torch.float32), + (32, torch.float32, torch.float32), + (32, torch.float64, torch.float32), + (16, torch.float32, torch.float16), + (16, torch.float64, torch.float16), + ], +) +def test_lite_module_forward_conversion(precision, input_type, expected_type): + """Test that the LiteModule performs autocasting on the input tensors and during forward().""" + lite = EmptyLite(precision=precision, accelerator="gpu", devices=1) + device = torch.device("cuda", 0) + + def check_autocast(forward_input): + assert precision != 16 or torch.is_autocast_enabled() + return forward_input + + module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast) + lite_module = _LiteModule(module, lite._precision_plugin).to(device) + out = lite_module(torch.rand(1, dtype=input_type, device=device)) + assert module.call_args[0][0].dtype == expected_type + assert out.dtype == torch.get_default_dtype() + + +@pytest.mark.parametrize( + "src_device, dest_device", + [ + (torch.device("cpu"), torch.device("cpu")), + pytest.param(torch.device("cpu"), torch.device("cuda", 0), marks=RunIf(min_gpus=1)), + pytest.param(torch.device("cuda", 0), torch.device("cpu"), marks=RunIf(min_gpus=1)), + ], +) +def test_lite_dataloader_device_placement(src_device, dest_device): + """Test that the LiteDataLoader moves data to the device in its iterator.""" + sample0 = torch.tensor(0, device=src_device) + sample1 = torch.tensor(1, device=src_device) + sample2 = {"data": torch.tensor(2, device=src_device)} + sample3 = {"data": torch.tensor(3, device=src_device)} + data = [sample0, sample1, sample2, sample3] + lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2) + iterator = iter(lite_dataloader) + + batch0 = next(iterator) + assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device)) + + batch1 = next(iterator) + assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) + + +def 
test_lite_optimizer_wraps(): + """Test that the LiteOptimizer fully wraps the optimizer.""" + optimizer_cls = torch.optim.SGD + optimizer = Mock(spec=optimizer_cls) + lite_optimizer = _LiteOptimizer(optimizer, Mock()) + assert lite_optimizer.optimizer is optimizer + assert isinstance(lite_optimizer, optimizer_cls) + + +def test_lite_optimizer_steps(): + """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" + optimizer = Mock() + accelerator = Mock() + lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) + lite_optimizer.step() + accelerator.optimizer_step.assert_called_once() + accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 2793c71560a81..ea31dbaf7d0a1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -771,24 +771,24 @@ def test_auto_add_worker_init_fn(): trainer = Trainer() # without pl.seed_everything() - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is not None
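The final hunks turn `auto_add_worker_init_fn` into the static `_auto_add_worker_init_fn(dataloader, rank)`, so Lite can reuse it without a `Trainer` instance. A small sketch of what that hook does once `seed_everything(..., workers=True)` has set the environment flag; the dataloader and rank below are illustrative:

    import os
    from functools import partial

    from torch.utils.data import DataLoader

    from pytorch_lightning.utilities.seed import pl_worker_init_function, seed_everything

    seed_everything(0, workers=True)  # exports PL_GLOBAL_SEED and PL_SEED_WORKERS=1

    dataloader = DataLoader(range(8), num_workers=2)
    rank = 0  # Lite passes self.global_rank here; the Trainer passes its own global_rank
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
        dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)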