From 2163721488a21141dc163ef2b257539e46d7e3ca Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 13 Feb 2021 12:06:48 +0000 Subject: [PATCH 01/58] Add initial deepspeed changes --- docs/source/advanced/multi_gpu.rst | 71 ++++++ pytorch_lightning/accelerators/accelerator.py | 6 +- .../accelerators/accelerator_connector.py | 42 +++- pytorch_lightning/plugins/__init__.py | 4 + .../plugins/precision/__init__.py | 1 + .../plugins/precision/deepspeed_precision.py | 46 ++++ .../plugins/training_type/__init__.py | 1 + .../plugins/training_type/deepspeed.py | 215 ++++++++++++++++ .../training_type/training_type_plugin.py | 9 +- pytorch_lightning/trainer/trainer.py | 7 +- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/apply_func.py | 19 ++ pytorch_lightning/utilities/enums.py | 1 + pytorch_lightning/utilities/imports.py | 1 + requirements/extra.txt | 2 + tests/plugins/test_deepspeed_plugin.py | 232 ++++++++++++++++++ 16 files changed, 644 insertions(+), 14 deletions(-) create mode 100644 pytorch_lightning/plugins/precision/deepspeed_precision.py create mode 100644 pytorch_lightning/plugins/training_type/deepspeed.py create mode 100644 tests/plugins/test_deepspeed_plugin.py diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 6619eed0209c6..0fbe0eea6236f 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -601,6 +601,8 @@ Lightning currently offers the following methods to leverage model parallelism: - Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**) - Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential ` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization) +.. _sharded: + Sharded Training ^^^^^^^^^^^^^^^^ Lightning integration of optimizer sharded training provided by `FairScale `_. @@ -666,6 +668,75 @@ Sharded Training can work across all DDP variants by adding the additional ``--p Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required. +---------- + +.. _deepspeed: + +DeepSpeed +^^^^^^^^^ +`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful memory optimized optimizers such as 1-bit Adam. +We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :doc:`sharded` first before trying DeepSpeed's further optimizations. + +To use DeepSpeed, you first need to install DeepSpeed using the commands below. + +.. code-block:: bash + + pip install mpi4py deepspeed + +If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvidia-smi``). +Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. + +Below we show an example of running `ZeRO Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. +For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. + +DeepSpeed requires that the optimizers and schedulers are defined within a config file. We've included the config to enable ZeRO-Offload. + +.. code-block:: python + + from pytorch_lightning import Trainer + + deepspeed_config = { + "optimizer": { + "type": "Adam", + "params": { + "lr": 3e-5, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": 1e-9, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": True, + "contiguous_gradients": True, + "overlap_comm": True + } + } + + model = MyModel() + trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config=deepspeed_config, precision=16) # zero offload requires mixed precision + trainer.fit(model) + +We also support taking the config as a json formatted file. + +.. code-block:: python + + from pytorch_lightning import Trainer + + model = MyModel() + trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config="/path/to/deepspeed_config.json", precision=16) # zero offload requires mixed precision + trainer.fit(model) + + ---------- .. _sequential-parallelism: diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index e348a57b5c103..9c747da6af504 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -279,7 +279,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs) def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): - optimizer.step(closure=lambda_closure, **kwargs) + self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs) def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: """Zeros all model parameter's gradients""" @@ -312,7 +312,9 @@ def setup_optimizers(self, trainer: "Trainer"): """ if trainer.testing is True: return - optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module) + optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( + trainer=trainer, model=self.lightning_module + ) self.optimizers = optimizers self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index cfa9545ad6aee..37731b4e128fc 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -30,6 +30,8 @@ DDPShardedPlugin, DDPSpawnPlugin, DDPSpawnShardedPlugin, + DeepSpeedPlugin, + DeepSpeedPrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -71,6 +73,7 @@ def __init__( gpus, num_nodes, sync_batchnorm, + deepspeed_config, benchmark, replace_sampler_ddp, deterministic, @@ -90,6 +93,7 @@ def __init__( self.gpus = gpus self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm + self.deepspeed_config = deepspeed_config self.benchmark = benchmark self.replace_sampler_ddp = replace_sampler_ddp self.deterministic = deterministic @@ -243,7 +247,7 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN + DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED ) @property @@ -254,6 +258,10 @@ def use_ddp2(self) -> bool: def use_horovod(self) -> bool: return self._distrib_type == DistributedType.HOROVOD + @property + def use_deepspeed(self) -> bool: + return self._distrib_type == DistributedType.DEEPSPEED + @property def is_distributed(self) -> bool: is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod @@ -290,15 +298,19 @@ def is_using_torchelastic(self) -> bool: return te_flags_passed def select_precision_plugin(self) -> PrecisionPlugin: + self._set_precision_type() + + if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): + return DeepSpeedPrecisionPlugin(self.precision) + if self.precision == 32: - self.amp_type = None return PrecisionPlugin() elif self.precision == 16: if self.on_tpu: return TPUHalfPrecisionPlugin() - if self.amp_type == "native": + if self.amp_type == AMPType.NATIVE: if not _NATIVE_AMP_AVAILABLE: rank_zero_warn( "You have asked for native AMP but your PyTorch version does not support it." @@ -309,7 +321,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException( "You have asked for native AMP on CPU, but AMP is only available on GPU." ) - self.amp_type = "apex" + self.amp_type = AMPType.NATIVE elif self.on_cpu: raise MisconfigurationException( "You have asked for native AMP on CPU, but AMP is only available on GPU." @@ -318,10 +330,9 @@ def select_precision_plugin(self) -> PrecisionPlugin: log.info("Using native 16bit precision.") if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): return ShardedNativeMixedPrecisionPlugin() - self.amp_type = AMPType.NATIVE return NativeMixedPrecisionPlugin() - if self.amp_type == "apex": + if self.amp_type == AMPType.APEX: if not _APEX_AVAILABLE: rank_zero_warn( "You have asked for Apex AMP but you have not installed it yet." @@ -334,14 +345,29 @@ def select_precision_plugin(self) -> PrecisionPlugin: "please using native AMP for 16-bit precision." ) log.info("Using APEX 16bit precision.") - self.amp_type = AMPType.APEX return ApexMixedPrecisionPlugin(self.amp_level) else: raise NotImplementedError("We only support precisions 32 and 16!") - def select_training_type_plugin(self) -> TrainingTypePlugin: + def _set_precision_type(self): + if self.precision == 32: + self.amp_type = None + elif self.precision == 16: + if self.amp_type == 'amp': + self.amp_type = AMPType.NATIVE + elif self.amp_type == 'apex': + self.amp_type = AMPType.APEX + + def select_training_type_plugin(self): if self.use_ddp2: plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) + elif self.use_ddp and self.use_deepspeed: + plugin = DeepSpeedPlugin( + config=self.deepspeed_config, + num_nodes=self.num_nodes, + cluster_environment=self.select_cluster_environment(), + parallel_devices=self.parallel_devices + ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 2d9086c2e18ad..dec672d025294 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,5 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -7,6 +8,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 @@ -25,6 +27,8 @@ "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", + "DeepSpeedPlugin", + "DeepSpeedPrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index 1b085c92aafd6..fc60deffcbb77 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,4 +1,5 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py new file mode 100644 index 0000000000000..d238d64e3cd10 --- /dev/null +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -0,0 +1,46 @@ +from typing import Callable, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + + +class DeepSpeedPrecisionPlugin(PrecisionPlugin): + + def __init__(self, precision): + super().__init__() + self.precision = precision + + def pre_optimizer_step( + self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs + ) -> bool: + lambda_closure() + return True + + def backward( + self, + lightning_module: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + # todo: hack around for deepspeed engine to call backward + # Means that the lightning module backward function is never called + # This is an issue if the user overrides the backwards function + deepspeed_engine = lightning_module.trainer.model + deepspeed_engine.backward(closure_loss) + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + """ + DeepSpeed handles clipping gradients via the training type plugin. + """ + pass diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index a5a644fc6568c..b73c6351de181 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -1,6 +1,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py new file mode 100644 index 0000000000000..6fcc92fd26451 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -0,0 +1,215 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Callable, List, Tuple, Union + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.apply_func import move_float_tensors_to_half +from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE +from pytorch_lightning.utilities.seed import seed_everything + +if _DEEPSPEED_AVAILABLE: + import deepspeed +else: + deepspeed = None + + +class LightningDeepSpeedModule(_LightningModuleWrapperBase): + + def __init__(self, pl_module: LightningModule, precision: int): + super().__init__(pl_module) + self.module = pl_module + self.precision = precision + + def forward(self, *inputs, **kwargs): + if self.precision == 16: + inputs = move_float_tensors_to_half(inputs) + return super().forward(*inputs, **kwargs) + + +class DeepSpeedPlugin(DDPPlugin): + distributed_backend = "deepspeed" + + def __init__( + self, + config: Union[Path, str, dict], + logging_level: int = logging.WARN, + num_nodes=1, + parallel_devices: List[torch.device] = None, + cluster_environment: ClusterEnvironment = None, + ) -> None: + super().__init__( + parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment + ) + if isinstance(config, str) or isinstance(config, Path): + with open(config) as f: + self.config = json.load(f) + else: + self.config = config + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + def pre_training(self): + self.set_world_ranks() + self.init_ddp_connection(self.global_rank, self.world_size) + + self.init_deepspeed() + + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + # move the model to the correct device + self.model_to_device() + self.barrier() + + def init_deepspeed(self): + if not self._config_initialized: + self._format_config() + self._config_initialized = True + + precision = self.lightning_module.trainer.accelerator_backend.precision + model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + model, optimizer, _, lr_scheduler = deepspeed.initialize( + args=SimpleNamespace(local_rank=self.local_rank), + model=LightningDeepSpeedModule(pl_module=self.model, precision=precision), + model_parameters=model_parameters, + config_params=self.config, + ) + trainer = self.lightning_module.trainer + if self.lightning_module.training: + trainer.optimizers = [optimizer] + trainer.lr_schedulers = self.configure_scheduler(lr_scheduler) + trainer.convert_to_lightning_optimizers() + self.model = model + + def configure_scheduler(self, lr_scheduler): + # this duplicates the defaults from init_optimizers + scheduler = { + 'scheduler': lr_scheduler, + 'name': None, # no custom name + 'interval': 'epoch', # after epoch is over + 'frequency': 1, # every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': None, # value to monitor for ReduceLROnPlateau + 'strict': True, # enforce that the monitor exists for ReduceLROnPlateau + } + return [scheduler] + + @property + def lightning_module(self): + # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early + module = getattr(self.model, "module", self.model) + return module.module if isinstance(module, LightningDeepSpeedModule) else module + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) + return distributed_sampler_kwargs + + def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]: + # Skip initializing optimizers as DeepSpeed handles optimizers via config. + # User may have specified config options instead in configure_optimizers, but this is handled + # via `_format_config` + return [], [], [] # empty optimizers, schedulers and frequencies + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + self.model.step(**kwargs) + + def _format_config(self): + if not self.config: + raise MisconfigurationException( + "To use DeepSpeed you must pass in a DeepSpeed config dictionary, or path to a json config." + "todo: Doc Link." + ) + self._format_optimizer_config() + self._format_batch_size_grad_accum_config() + self._format_precision_config() + + def _format_optimizer_config(self): + if ("optimizer" not in self.config) or ("scheduler" not in self.config): + self.optimizer, self.scheduler = self.model.configure_optimizers() + + if not (isinstance(self.optimizer, dict) or isinstance(self.scheduler, dict)): + raise MisconfigurationException( + "If you have not specified an optimizer or scheduler within the DeepSpeed config " + "then you must return a dict from `configure_optimizers` within the LightningModule. " + "See x for more information." + ) + + if not len(self.optimizer) == 1 or len(self.scheduler) == 1: + raise MisconfigurationException("DeepSpeed currently only supports single optimizer, single scheduler.") + + optimizer_name, optimizer_params = self.optimizer.items()[0] + scheduler_name, scheduler_params = self.scheduler.items()[0] + + self.config["optimizer"] = { + "type": optimizer_name, + "params": optimizer_params, + } + self.config["scheduler"] = { + "type": scheduler_name, + "params": scheduler_params, + } + + def _format_batch_size_grad_accum_config(self): + if "train_batch_size" in self.config or "train_micro_batch_size_per_gpu" in self.config: + raise MisconfigurationException( + "Within the DeepSpeed config, do not set train_batch_size or train_micro_batch_size_per_gpu " + "as these will be passed from the data-loader." + ) + if "gradient_accumulation_steps" in self.config: + raise MisconfigurationException( + "Within the DeepSpeed config, do not set gradient_accumulation_steps " + "as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." + ) + self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size + self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches + if "gradient_clipping" not in self.config: + self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val + + def _format_precision_config(self): + + amp_type = self.lightning_module.trainer.accelerator_connector.amp_type + amp_level = self.lightning_module.trainer.accelerator_connector.amp_level + precision = self.lightning_module.trainer.accelerator_connector.precision + if precision == 16: + if "amp" not in self.config and amp_type == AMPType.NATIVE: + self.config["fp16"] = {"enabled": True} + elif "apex" not in self.config and amp_type == AMPType.APEX: + self.config["amp"] = { + "enabled": True, + "opt_level": amp_level, + } diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index db0e390c4b03e..c1ef243b1b414 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from abc import ABC, abstractmethod -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torch.nn import Module @@ -144,3 +143,9 @@ def test_step_end(self, output): def on_save(self, checkpoint: dict) -> dict: return checkpoint + + def init_optimizers(self, trainer: "Trainer", model: LightningModule): + return trainer.init_optimizers(model) + + def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + optimizer.step(closure=lambda_closure, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4f9c5d4f5e19f..b321d54ca924e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -135,6 +135,7 @@ def __init__( move_metrics_to_cpu: bool = False, enable_pl_optimizer: bool = None, # todo: remove in v1.3 multiple_trainloader_mode: str = 'max_size_cycle', + deepspeed_config: Optional[Union[Path, str, Dict]] = None, ): r""" Customize every aspect of training via flags @@ -293,6 +294,8 @@ def __init__( In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets reload when reaching the minimum length of datasets. + + deepspeed_config: Path to deepspeed config or dict when using deepspeed accelerator backend. """ super().__init__() self._running_stage = None @@ -307,8 +310,8 @@ def __init__( self.optimizer_connector = OptimizerConnector(self) self.accelerator_connector = BackendConnector( - num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, - replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins + num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, + deepspeed_config, benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 889ed96f43679..cf3aa06f305b8 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401 _APEX_AVAILABLE, _BOLTS_AVAILABLE, + _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2f7425bf3beb0..3b2f6f6683b7d 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -164,3 +164,22 @@ def convert_to_tensors(data, device: torch.device = None): for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) return data + + +def move_float_tensors_to_half(batch: Any): + """ + Transfers all float32 tensors in batch to half precision. + Args: + batch: A float32 tensor or collection of float32 tensors. + See :func:`apply_to_collection` for a list of supported collection types. + Return: + the same collection but with all contained float32 tensors converted to half precision. + """ + + def batch_to(data): + return data.half() + + dtypes = [torch.FloatTensor, torch.cuda.FloatTensor] + for dtype in dtypes: + batch = apply_to_collection(batch, dtype=dtype, function=batch_to) + return batch diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index c7796b433f1ed..3e4add4fb68d1 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -62,6 +62,7 @@ class DistributedType(LightningEnum): DDP = 'ddp' DDP2 = 'ddp2' DDP_SPAWN = 'ddp_spawn' + DEEPSPEED = 'deepspeed' HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 4d1b38eaf5949..f29db7a5c31d0 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -55,6 +55,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized') _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') +_DEEPSPEED_AVAILABLE = platform.system() != 'Windows' and _module_available('deepspeed') and _module_available('mpi4py') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') diff --git a/requirements/extra.txt b/requirements/extra.txt index 1654f050398a9..47cfc67a4a5a6 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,3 +8,5 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip +deepspeed>=0.3.10 +mpi4py>=3.0.3 diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py new file mode 100644 index 0000000000000..c04e917eaf3b0 --- /dev/null +++ b/tests/plugins/test_deepspeed_plugin.py @@ -0,0 +1,232 @@ +import os +import platform + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_choice(tmpdir): + """ + Test to ensure that plugin is correctly chosen + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + accelerator='deepspeed', + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin(tmpdir): + """ + Test to ensure that the plugin can be passed directly, and parallel devices is correctly set. + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + plugins=[DeepSpeedPlugin(config={})], + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") +def test_deepspeed_amp_choice(tmpdir): + """ + Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via + Custom DeepSpeedPrecisionPlugin + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.accelerator_backend.precision_plugin.precision == 16 + raise SystemExit() + + model = BoringModel() + trainer = Trainer(fast_dev_run=True, accelerator='deepspeed', callbacks=[CB()], amp_backend='native', precision=16) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex") +def test_deepspeed_apex_choice(tmpdir): + """ + Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via + Custom DeepSpeedPrecisionPlugin + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.accelerator_backend.precision_plugin.precision == 16 + raise SystemExit() + + model = BoringModel() + trainer = Trainer(fast_dev_run=True, accelerator='deepspeed', callbacks=[CB()], amp_backend='apex', precision=16) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_invalid_deepspeed_without_config(tmpdir): + """ + Test to ensure if a DeepSpeed config is not provided, we throw an exception. + """ + model = BoringModel() + trainer = Trainer( + accelerator='deepspeed', + gpus=1, + fast_dev_run=True, + ) + + with pytest.raises( + MisconfigurationException, + match="To use DeepSpeed you must pass in a DeepSpeed config dictionary, or path to a json config." + ): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_deepspeed(tmpdir): + """ + Test to ensure deepspeed works correctly with a valid config object, + and saves the model weights to load correctly. + """ + deepspeed_config = { + "optimizer": { + "type": "Adam", + "params": { + "lr": 3e-5, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": 1e-9, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + } + } + model = BoringModel() + trainer = Trainer(accelerator='deepspeed', gpus=1, fast_dev_run=True, deepspeed_config=deepspeed_config) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Assert model parameters are identical after loading + for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(orig_param, trained_model_param) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_deepspeed_offload_zero_multigpu(tmpdir): + """ + Test to ensure that zero offload with multiple GPUs works correctly. + """ + deepspeed_config = { + "optimizer": { + "type": "Adam", + "params": { + "lr": 3e-5, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": 1e-9, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + }, + "zero_allow_untested_optimizer": False, + "zero_optimization": { + "stage": 2, + "cpu_offload": True, + "contiguous_gradients": True, + "overlap_comm": True + } + } + model = BoringModel() + trainer = Trainer( + accelerator='deepspeed', + gpus=2, + fast_dev_run=True, + deepspeed_config=deepspeed_config, + precision=16, + ) + + trainer.fit(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + # carry out the check only on rank 0 + if trainer.global_rank == 0: + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + saved_model = saved_model.float() + model = model.float() + # Assert model parameters are identical after loading + for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(orig_param, trained_model_param) From 14c7b612c946033409bfc6b64c7dc6629efaeb2a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 14 Feb 2021 13:36:35 +0000 Subject: [PATCH 02/58] Address code review --- docs/source/advanced/multi_gpu.rst | 2 +- .../accelerators/accelerator_connector.py | 12 +----- .../plugins/precision/deepspeed_precision.py | 10 ++++- .../plugins/training_type/deepspeed.py | 39 ++++++++++++------- pytorch_lightning/utilities/apply_func.py | 19 --------- 5 files changed, 36 insertions(+), 46 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 0fbe0eea6236f..c2be3c68738ad 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -683,7 +683,7 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below. pip install mpi4py deepspeed -If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvidia-smi``). +If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. Below we show an example of running `ZeRO Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 37731b4e128fc..3ca41c30c0c57 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -298,7 +298,8 @@ def is_using_torchelastic(self) -> bool: return te_flags_passed def select_precision_plugin(self) -> PrecisionPlugin: - self._set_precision_type() + # set precision type + self.amp_type = AMPType.from_str(self.amp_type) if self._distrib_type == DistributedType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): return DeepSpeedPrecisionPlugin(self.precision) @@ -349,15 +350,6 @@ def select_precision_plugin(self) -> PrecisionPlugin: else: raise NotImplementedError("We only support precisions 32 and 16!") - def _set_precision_type(self): - if self.precision == 32: - self.amp_type = None - elif self.precision == 16: - if self.amp_type == 'amp': - self.amp_type = AMPType.NATIVE - elif self.amp_type == 'apex': - self.amp_type = AMPType.APEX - def select_training_type_plugin(self): if self.use_ddp2: plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index d238d64e3cd10..f4badd98a617d 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -16,8 +16,14 @@ def __init__(self, precision): def pre_optimizer_step( self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs ) -> bool: + # DeepSpeed not support closures. lambda_closure() - return True + + if not pl_module.automatic_optimization: + pl_module.trainer.call_hook("on_after_backward") + optimizer.step() + + return False def backward( self, @@ -33,7 +39,7 @@ def backward( # Means that the lightning module backward function is never called # This is an issue if the user overrides the backwards function deepspeed_engine = lightning_module.trainer.model - deepspeed_engine.backward(closure_loss) + deepspeed_engine.backward(closure_loss, **kwargs) # once backward has been applied, release graph closure_loss = closure_loss.detach() diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 6fcc92fd26451..01d01e16b3a3b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -17,7 +17,7 @@ import os from pathlib import Path from types import SimpleNamespace -from typing import Callable, List, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch @@ -26,7 +26,7 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.apply_func import move_float_tensors_to_half +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE @@ -47,9 +47,18 @@ def __init__(self, pl_module: LightningModule, precision: int): def forward(self, *inputs, **kwargs): if self.precision == 16: - inputs = move_float_tensors_to_half(inputs) + inputs = self._move_float_tensors_to_half(inputs) + return super().forward(*inputs, **kwargs) + def _move_float_tensors_to_half(self, batch: Any): + + def batch_to(data): + return data.half() + + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=batch_to) + return batch + class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" @@ -58,16 +67,21 @@ def __init__( self, config: Union[Path, str, dict], logging_level: int = logging.WARN, - num_nodes=1, - parallel_devices: List[torch.device] = None, - cluster_environment: ClusterEnvironment = None, + num_nodes: int = 1, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, ) -> None: super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment ) if isinstance(config, str) or isinstance(config, Path): - with open(config) as f: - self.config = json.load(f) + if os.path.exists(config): + with open(config) as f: + self.config = json.load(f) + else: + raise MisconfigurationException( + f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" + ) else: self.config = config self._config_initialized = False @@ -90,9 +104,6 @@ def pre_training(self): # set the ranks and devices self.dist.rank = self.global_rank self.dist.device = self.root_device - - # move the model to the correct device - self.model_to_device() self.barrier() def init_deepspeed(self): @@ -164,9 +175,9 @@ def _format_optimizer_config(self): if not (isinstance(self.optimizer, dict) or isinstance(self.scheduler, dict)): raise MisconfigurationException( - "If you have not specified an optimizer or scheduler within the DeepSpeed config " - "then you must return a dict from `configure_optimizers` within the LightningModule. " - "See x for more information." + "If you have not specified an optimizer or scheduler within the DeepSpeed config" + " then you must return a dict from `configure_optimizers` within the LightningModule." + " See x for more information." ) if not len(self.optimizer) == 1 or len(self.scheduler) == 1: diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 3b2f6f6683b7d..2f7425bf3beb0 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -164,22 +164,3 @@ def convert_to_tensors(data, device: torch.device = None): for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) return data - - -def move_float_tensors_to_half(batch: Any): - """ - Transfers all float32 tensors in batch to half precision. - Args: - batch: A float32 tensor or collection of float32 tensors. - See :func:`apply_to_collection` for a list of supported collection types. - Return: - the same collection but with all contained float32 tensors converted to half precision. - """ - - def batch_to(data): - return data.half() - - dtypes = [torch.FloatTensor, torch.cuda.FloatTensor] - for dtype in dtypes: - batch = apply_to_collection(batch, dtype=dtype, function=batch_to) - return batch From e8ab7fd0b71183bb5c7d5d647f78cac1356403d5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 10:30:02 +0000 Subject: [PATCH 03/58] Move static method outside of function --- pytorch_lightning/plugins/training_type/deepspeed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 01d01e16b3a3b..098819456bc8b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -51,12 +51,12 @@ def forward(self, *inputs, **kwargs): return super().forward(*inputs, **kwargs) - def _move_float_tensors_to_half(self, batch: Any): - - def batch_to(data): - return data.half() + @staticmethod + def batch_to(data): + return data.half() - batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=batch_to) + def _move_float_tensors_to_half(self, batch: Any): + batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to) return batch From 5b1e0913de2d06231d53e8f22c2bf82087674c9a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 15:19:37 +0000 Subject: [PATCH 04/58] Fixes --- docs/source/advanced/multi_gpu.rst | 33 +++-- .../accelerators/accelerator_connector.py | 3 - .../plugins/precision/deepspeed_precision.py | 5 +- .../plugins/training_type/deepspeed.py | 23 ++- pytorch_lightning/trainer/trainer.py | 7 +- pytorch_lightning/utilities/imports.py | 2 +- requirements/extra.txt | 3 +- tests/plugins/test_deepspeed_plugin.py | 135 +++++++++--------- 8 files changed, 113 insertions(+), 98 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index c2be3c68738ad..e80261d9605a3 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -670,12 +670,13 @@ Internally we re-initialize your optimizers and shard them across your machines ---------- -.. _deepspeed: +.. _deep_speed: DeepSpeed ^^^^^^^^^ -`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful memory optimized optimizers such as 1-bit Adam. -We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :doc:`sharded` first before trying DeepSpeed's further optimizations. +`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful memory optimized optimizers such as `1-bit Adam `_. +Using the plugin, we were able to train model sizes of 10 Billion+ parameters and above, with a lot of useful information in this `issue `_ and DeepSpeed `docs `_. +We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations. To use DeepSpeed, you first need to install DeepSpeed using the commands below. @@ -689,11 +690,16 @@ Additionally if you run into any issues installing m4py, ensure you have openmpi Below we show an example of running `ZeRO Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. -DeepSpeed requires that the optimizers and schedulers are defined within a config file. We've included the config to enable ZeRO-Offload. +DeepSpeed requires that the optimizers and schedulers are defined within a config file. +We've included the config to enable ZeRO-Offload, as well as set the bucket sizes to reasonable defaults for low VRAM GPUs (less than 7GB). + +.. note:: + To use ZeRO-Offload to train large models, you must use ``precision=16`` or set precision via `the DeepSpeed config as seen `_. .. code-block:: python from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin deepspeed_config = { "optimizer": { @@ -715,15 +721,17 @@ DeepSpeed requires that the optimizers and schedulers are defined within a confi } }, "zero_optimization": { - "stage": 2, - "cpu_offload": True, - "contiguous_gradients": True, - "overlap_comm": True + "stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning) + "cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU + "contiguous_gradients": True, # Copy gradients to a contiguous memory buffer to reduce fragmentation. + "overlap_comm": True # Overlap the reduce operation of gradients with the backwards pass for speed optimization. + "allgather_bucket_size": 2e8, # Controls the number of elements to all gather at once. + "reduce_bucket_size": 2e8, # Controls the number of elements we reduce/allreduce at once. } } model = MyModel() - trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config=deepspeed_config, precision=16) # zero offload requires mixed precision + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16) trainer.fit(model) We also support taking the config as a json formatted file. @@ -731,11 +739,16 @@ We also support taking the config as a json formatted file. .. code-block:: python from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin model = MyModel() - trainer = Trainer(accelerator='deepspeed', gpus=4, deepspeed_config="/path/to/deepspeed_config.json", precision=16) # zero offload requires mixed precision + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16) trainer.fit(model) +.. note:: + We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters. Larger values will be more effecient in terms of throughput time, but will require more memory. + DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameter. + As reference, a reduce buffer size of 2e8 means you'll allocate roughly 3.6GB of VRAM for the buffer. ---------- diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 3ca41c30c0c57..a6abb88210b2f 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -73,7 +73,6 @@ def __init__( gpus, num_nodes, sync_batchnorm, - deepspeed_config, benchmark, replace_sampler_ddp, deterministic, @@ -93,7 +92,6 @@ def __init__( self.gpus = gpus self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm - self.deepspeed_config = deepspeed_config self.benchmark = benchmark self.replace_sampler_ddp = replace_sampler_ddp self.deterministic = deterministic @@ -355,7 +353,6 @@ def select_training_type_plugin(self): plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) elif self.use_ddp and self.use_deepspeed: plugin = DeepSpeedPlugin( - config=self.deepspeed_config, num_nodes=self.num_nodes, cluster_environment=self.select_cluster_environment(), parallel_devices=self.parallel_devices diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index f4badd98a617d..72ca4a526df6d 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -16,12 +16,13 @@ def __init__(self, precision): def pre_optimizer_step( self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs ) -> bool: + deepspeed_engine = pl_module.trainer.model # DeepSpeed not support closures. lambda_closure() - if not pl_module.automatic_optimization: + if pl_module.automatic_optimization: pl_module.trainer.call_hook("on_after_backward") - optimizer.step() + deepspeed_engine.step() return False diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 098819456bc8b..b66dc46e980ca 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -62,10 +62,11 @@ def _move_float_tensors_to_half(self, batch: Any): class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" + DEEPSPEED_ENV_VAR = "DEEPSPEED_CONFIG_PATH" def __init__( self, - config: Union[Path, str, dict], + config: Optional[Union[Path, str, dict]] = None, logging_level: int = logging.WARN, num_nodes: int = 1, parallel_devices: Optional[List[torch.device]] = None, @@ -74,18 +75,28 @@ def __init__( super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment ) + self.config = self._load_config(config) + self._config_initialized = False + deepspeed.utils.logging.logger.setLevel(logging_level) + + def _load_config(self, config): + if config is None: + if self.DEEPSPEED_ENV_VAR not in os.environ: + raise MisconfigurationException( + f"You did not pass a DeepSpeed config object or path for DeepSpeed. This can be passed" + f" via instantiating the `DeepSpeedPlugin` object, or by the DEEPSPEED_CONFIG_PATH env variable." + f" see x for more information." + ) + config = os.environ.get(self.DEEPSPEED_ENV_VAR) if isinstance(config, str) or isinstance(config, Path): if os.path.exists(config): with open(config) as f: - self.config = json.load(f) + config = json.load(f) else: raise MisconfigurationException( f"You passed in a path to a DeepSpeed config but the path does not exist: {config}" ) - else: - self.config = config - self._config_initialized = False - deepspeed.utils.logging.logger.setLevel(logging_level) + return config def pre_training(self): self.set_world_ranks() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b321d54ca924e..4f9c5d4f5e19f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -135,7 +135,6 @@ def __init__( move_metrics_to_cpu: bool = False, enable_pl_optimizer: bool = None, # todo: remove in v1.3 multiple_trainloader_mode: str = 'max_size_cycle', - deepspeed_config: Optional[Union[Path, str, Dict]] = None, ): r""" Customize every aspect of training via flags @@ -294,8 +293,6 @@ def __init__( In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets reload when reaching the minimum length of datasets. - - deepspeed_config: Path to deepspeed config or dict when using deepspeed accelerator backend. """ super().__init__() self._running_stage = None @@ -310,8 +307,8 @@ def __init__( self.optimizer_connector = OptimizerConnector(self) self.accelerator_connector = BackendConnector( - num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, - deepspeed_config, benchmark, replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins + num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark, + replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins ) self.logger_connector = LoggerConnector(self, log_gpu_memory) self.model_connector = ModelConnector(self) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index f29db7a5c31d0..4e98f16b3c997 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -55,7 +55,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized') _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') -_DEEPSPEED_AVAILABLE = platform.system() != 'Windows' and _module_available('deepspeed') and _module_available('mpi4py') +_DEEPSPEED_AVAILABLE = platform.system() != 'Windows' and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') diff --git a/requirements/extra.txt b/requirements/extra.txt index 47cfc67a4a5a6..f37fda47ee0a0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,5 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip -deepspeed>=0.3.10 -mpi4py>=3.0.3 +https://github.com/microsoft/DeepSpeed/archive/ec8b1cb0a0a5752bba029da4bdc91616c0f5bec7.zip # TODO: move to DeepSpeed release version diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c04e917eaf3b0..1f16b316b7a50 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -1,3 +1,4 @@ +import json import os import platform @@ -11,28 +12,31 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel +PRETEND_N_OF_GPUS = 1 -@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -def test_deepspeed_choice(tmpdir): - """ - Test to ensure that plugin is correctly chosen - """ - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) - raise SystemExit() - - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - accelerator='deepspeed', - callbacks=[CB()], - ) - - with pytest.raises(SystemExit): - trainer.fit(model) +@pytest.fixture +def deepspeed_config(): + return { + "optimizer": { + "type": "Adam", + "params": { + "lr": 3e-5, + "betas": [0.998, 0.999], + "eps": 1e-5, + "weight_decay": 1e-9, + }, + }, + 'scheduler': { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + } + } + } @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @@ -76,7 +80,9 @@ def on_fit_start(self, trainer, pl_module): raise SystemExit() model = BoringModel() - trainer = Trainer(fast_dev_run=True, accelerator='deepspeed', callbacks=[CB()], amp_backend='native', precision=16) + trainer = Trainer( + fast_dev_run=True, plugins=[DeepSpeedPlugin(config={})], callbacks=[CB()], amp_backend='native', precision=16 + ) with pytest.raises(SystemExit): trainer.fit(model) @@ -99,7 +105,9 @@ def on_fit_start(self, trainer, pl_module): raise SystemExit() model = BoringModel() - trainer = Trainer(fast_dev_run=True, accelerator='deepspeed', callbacks=[CB()], amp_backend='apex', precision=16) + trainer = Trainer( + fast_dev_run=True, plugins=[DeepSpeedPlugin(config={})], callbacks=[CB()], amp_backend='apex', precision=16 + ) with pytest.raises(SystemExit): trainer.fit(model) @@ -114,9 +122,9 @@ def test_invalid_deepspeed_without_config(tmpdir): """ model = BoringModel() trainer = Trainer( - accelerator='deepspeed', - gpus=1, fast_dev_run=True, + plugins=[DeepSpeedPlugin(config={})], + gpus=1, ) with pytest.raises( @@ -126,39 +134,48 @@ def test_invalid_deepspeed_without_config(tmpdir): trainer.fit(model) +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_with_invalid_config_path(tmpdir): + """ + Test to ensure if we pass an invalid config path we throw an exception. + """ + + with pytest.raises( + MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" + ): + DeepSpeedPlugin(config='invalid_path.json') + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure if we pass an env variable, we load the config from the path. + """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("DEEPSPEED_CONFIG_PATH", config_path) + plugin = DeepSpeedPlugin() + assert plugin.config == deepspeed_config + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_deepspeed(tmpdir): +def test_deepspeed(tmpdir, deepspeed_config): """ Test to ensure deepspeed works correctly with a valid config object, and saves the model weights to load correctly. """ - deepspeed_config = { - "optimizer": { - "type": "Adam", - "params": { - "lr": 3e-5, - "betas": [0.998, 0.999], - "eps": 1e-5, - "weight_decay": 1e-9, - }, - }, - 'scheduler': { - "type": "WarmupLR", - "params": { - "last_batch_iteration": -1, - "warmup_min_lr": 0, - "warmup_max_lr": 3e-5, - "warmup_num_steps": 100, - } - } - } model = BoringModel() - trainer = Trainer(accelerator='deepspeed', gpus=1, fast_dev_run=True, deepspeed_config=deepspeed_config) + trainer = Trainer( + plugins=[DeepSpeedPlugin(deepspeed_config)], + gpus=1, + fast_dev_run=True, + ) trainer.fit(model) @@ -178,30 +195,12 @@ def test_deepspeed(tmpdir): @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_deepspeed_offload_zero_multigpu(tmpdir): +def test_deepspeed_offload_zero_multigpu(tmpdir, deepspeed_config): """ Test to ensure that zero offload with multiple GPUs works correctly. """ deepspeed_config = { - "optimizer": { - "type": "Adam", - "params": { - "lr": 3e-5, - "betas": [0.998, 0.999], - "eps": 1e-5, - "weight_decay": 1e-9, - }, - }, - 'scheduler': { - "type": "WarmupLR", - "params": { - "last_batch_iteration": -1, - "warmup_min_lr": 0, - "warmup_max_lr": 3e-5, - "warmup_num_steps": 100, - } - }, - "zero_allow_untested_optimizer": False, + **deepspeed_config, "zero_allow_untested_optimizer": False, "zero_optimization": { "stage": 2, "cpu_offload": True, @@ -211,13 +210,11 @@ def test_deepspeed_offload_zero_multigpu(tmpdir): } model = BoringModel() trainer = Trainer( - accelerator='deepspeed', + plugins=[DeepSpeedPlugin(deepspeed_config)], gpus=2, fast_dev_run=True, - deepspeed_config=deepspeed_config, precision=16, ) - trainer.fit(model) checkpoint_path = os.path.join(tmpdir, 'model.pt') From d6d90be95e0f282e38fc7ccc7ef79930f253d6e1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 15:29:53 +0000 Subject: [PATCH 05/58] Add missing annotation --- pytorch_lightning/accelerators/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index a6abb88210b2f..96489e578963c 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -348,7 +348,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: else: raise NotImplementedError("We only support precisions 32 and 16!") - def select_training_type_plugin(self): + def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) elif self.use_ddp and self.use_deepspeed: From ab4efdf52cb6fb8a00fffb88d6ad39d4d49867ec Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 15:30:19 +0000 Subject: [PATCH 06/58] Remove seed setting --- pytorch_lightning/plugins/training_type/deepspeed.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index b66dc46e980ca..a9218ea472758 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -104,11 +104,6 @@ def pre_training(self): self.init_deepspeed() - # TODO: check if needed - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - # set warning rank rank_zero_only.rank = self.global_rank From bffb9166d23eeba926aff54b5ef706f9a03a93b1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 15:33:06 +0000 Subject: [PATCH 07/58] Doc changes --- docs/source/advanced/multi_gpu.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index e80261d9605a3..7e300fbe26cd7 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -695,6 +695,7 @@ We've included the config to enable ZeRO-Offload, as well as set the bucket size .. note:: To use ZeRO-Offload to train large models, you must use ``precision=16`` or set precision via `the DeepSpeed config as seen `_. + All compatible arguments can be seen `in the DeepSpeed Config Json docs `_. .. code-block:: python @@ -734,6 +735,7 @@ We've included the config to enable ZeRO-Offload, as well as set the bucket size trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16) trainer.fit(model) + We also support taking the config as a json formatted file. .. code-block:: python @@ -745,6 +747,7 @@ We also support taking the config as a json formatted file. trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16) trainer.fit(model) + .. note:: We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters. Larger values will be more effecient in terms of throughput time, but will require more memory. DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameter. From 5c4444ddebe05a7f2681ad42549c80c94ad9f6f1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 20:15:17 +0000 Subject: [PATCH 08/58] Doc changes, add address reviews --- docs/source/advanced/multi_gpu.rst | 37 ++++-- .../accelerators/accelerator_connector.py | 3 +- .../plugins/training_type/deepspeed.py | 114 ++++++++++-------- pytorch_lightning/trainer/trainer.py | 3 +- tests/plugins/test_deepspeed_plugin.py | 30 +++++ 5 files changed, 122 insertions(+), 65 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 7e300fbe26cd7..3de32d77769a6 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -672,8 +672,12 @@ Internally we re-initialize your optimizers and shard them across your machines .. _deep_speed: -DeepSpeed -^^^^^^^^^ +DeepSpeed [EXPERIMENTAL] +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. note:: + The DeepSpeed plugin is experimental and the API is subject to change. Please create an `issue `_ if you run into any issues. + `DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful memory optimized optimizers such as `1-bit Adam `_. Using the plugin, we were able to train model sizes of 10 Billion+ parameters and above, with a lot of useful information in this `issue `_ and DeepSpeed `docs `_. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations. @@ -687,15 +691,15 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below. If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. -Below we show an example of running `ZeRO Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. +Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. DeepSpeed requires that the optimizers and schedulers are defined within a config file. We've included the config to enable ZeRO-Offload, as well as set the bucket sizes to reasonable defaults for low VRAM GPUs (less than 7GB). .. note:: - To use ZeRO-Offload to train large models, you must use ``precision=16`` or set precision via `the DeepSpeed config as seen `_. - All compatible arguments can be seen `in the DeepSpeed Config Json docs `_. + To use ZeRO-Offload to train large models, you must use ``precision=16`` or set precision via `the DeepSpeed config. `_. + All compatible arguments can be seen in the `DeepSpeed docs `_. .. code-block:: python @@ -724,10 +728,10 @@ We've included the config to enable ZeRO-Offload, as well as set the bucket size "zero_optimization": { "stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning) "cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU - "contiguous_gradients": True, # Copy gradients to a contiguous memory buffer to reduce fragmentation. - "overlap_comm": True # Overlap the reduce operation of gradients with the backwards pass for speed optimization. - "allgather_bucket_size": 2e8, # Controls the number of elements to all gather at once. - "reduce_bucket_size": 2e8, # Controls the number of elements we reduce/allreduce at once. + "contiguous_gradients": True, # Reduce gradient fragmentation. + "overlap_comm": True # Overlap reduce/backward operation of gradients for speed. + "allgather_bucket_size": 2e8, # Number of elements to all gather at once. + "reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once. } } @@ -736,7 +740,7 @@ We've included the config to enable ZeRO-Offload, as well as set the bucket size trainer.fit(model) -We also support taking the config as a json formatted file. +We support taking the config as a json formatted file: .. code-block:: python @@ -748,9 +752,18 @@ We also support taking the config as a json formatted file. trainer.fit(model) +You can use also use an environment variable via your Pytorch Lightning script: + +.. code-block:: bash + + DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py plugins=deepspeed + + .. note:: - We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters. Larger values will be more effecient in terms of throughput time, but will require more memory. - DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameter. + We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters based on your model size. Larger values will be more efficient in terms of throughput time, but will require more memory. + + DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameters. + As reference, a reduce buffer size of 2e8 means you'll allocate roughly 3.6GB of VRAM for the buffer. ---------- diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 96489e578963c..a7c12721004d7 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -34,6 +34,7 @@ DeepSpeedPrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, + Plugin, PrecisionPlugin, ShardedNativeMixedPrecisionPlugin, SingleDevicePlugin, @@ -146,7 +147,7 @@ def __init__( self.replace_sampler_ddp = replace_sampler_ddp - def handle_given_plugins(self, plugins: Optional[Sequence]): + def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]): plugins = plugins if plugins is not None else [] if isinstance(plugins, str): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index a9218ea472758..313904e765981 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -20,6 +20,8 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase @@ -27,10 +29,9 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE -from pytorch_lightning.utilities.seed import seed_everything if _DEEPSPEED_AVAILABLE: import deepspeed @@ -83,9 +84,9 @@ def _load_config(self, config): if config is None: if self.DEEPSPEED_ENV_VAR not in os.environ: raise MisconfigurationException( - f"You did not pass a DeepSpeed config object or path for DeepSpeed. This can be passed" - f" via instantiating the `DeepSpeedPlugin` object, or by the DEEPSPEED_CONFIG_PATH env variable." - f" see x for more information." + "You did not pass a DeepSpeed config object or path for DeepSpeed. This can be passed" + " via instantiating the `DeepSpeedPlugin` object, or by the DEEPSPEED_CONFIG_PATH env variable." + " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" ) config = os.environ.get(self.DEEPSPEED_ENV_VAR) if isinstance(config, str) or isinstance(config, Path): @@ -118,20 +119,61 @@ def init_deepspeed(self): self._config_initialized = True precision = self.lightning_module.trainer.accelerator_backend.precision + model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + + if self.lightning_module.trainer.training: + self._initialize_deepspeed_train(model) + else: + self._initialize_deepspeed_inference(model) + + def _init_scheduler_optimizer(self): + optimizer, lightning_scheduler, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( + self.lightning_module + ) + if (len(optimizer) != 1) or (lightning_scheduler is not None and len(lightning_scheduler) != 1): + raise MisconfigurationException( + "DeepSpeed currently only supports single optimizer, single optional scheduler." + ) + lightning_scheduler = lightning_scheduler[0]['scheduler'] if lightning_scheduler else None + optimizer = optimizer[0] + return optimizer, lightning_scheduler, optimizer_frequencies + + def _initialize_deepspeed_train(self, model): + optimizer, lightning_scheduler, optimizer_frequencies = None, None, None + if "optimizer" not in self.config: + rank_zero_info( + "You have not specified an optimizer or scheduler within the DeepSpeed config." + "Using `configure_optimizers` to define optimizer and scheduler." + ) + optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer() + model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) model, optimizer, _, lr_scheduler = deepspeed.initialize( args=SimpleNamespace(local_rank=self.local_rank), - model=LightningDeepSpeedModule(pl_module=self.model, precision=precision), + model=model, model_parameters=model_parameters, + optimizer=optimizer, + lr_scheduler=lightning_scheduler, config_params=self.config, ) + + # set optimizer for save/load, but deepspeed manages the specific optimizer logic trainer = self.lightning_module.trainer - if self.lightning_module.training: - trainer.optimizers = [optimizer] - trainer.lr_schedulers = self.configure_scheduler(lr_scheduler) - trainer.convert_to_lightning_optimizers() + trainer.optimizers = [optimizer] + trainer.convert_to_lightning_optimizers() self.model = model + def _initialize_deepspeed_inference(self, model): + # move the model to the correct device + self.model_to_device() + + self.pre_configure_ddp() + self._model = DistributedDataParallel( + model, + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + def configure_scheduler(self, lr_scheduler): # this duplicates the defaults from init_optimizers scheduler = { @@ -157,9 +199,9 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]: - # Skip initializing optimizers as DeepSpeed handles optimizers via config. + # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled - # via `_format_config` + # via `_initialize_deepspeed_train` return [], [], [] # empty optimizers, schedulers and frequencies def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): @@ -168,51 +210,21 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla def _format_config(self): if not self.config: raise MisconfigurationException( - "To use DeepSpeed you must pass in a DeepSpeed config dictionary, or path to a json config." - "todo: Doc Link." + "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." + " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" ) - self._format_optimizer_config() - self._format_batch_size_grad_accum_config() + self._format_batch_size_and_grad_accum_config() self._format_precision_config() - def _format_optimizer_config(self): - if ("optimizer" not in self.config) or ("scheduler" not in self.config): - self.optimizer, self.scheduler = self.model.configure_optimizers() - - if not (isinstance(self.optimizer, dict) or isinstance(self.scheduler, dict)): - raise MisconfigurationException( - "If you have not specified an optimizer or scheduler within the DeepSpeed config" - " then you must return a dict from `configure_optimizers` within the LightningModule." - " See x for more information." - ) - - if not len(self.optimizer) == 1 or len(self.scheduler) == 1: - raise MisconfigurationException("DeepSpeed currently only supports single optimizer, single scheduler.") - - optimizer_name, optimizer_params = self.optimizer.items()[0] - scheduler_name, scheduler_params = self.scheduler.items()[0] - - self.config["optimizer"] = { - "type": optimizer_name, - "params": optimizer_params, - } - self.config["scheduler"] = { - "type": scheduler_name, - "params": scheduler_params, - } - - def _format_batch_size_grad_accum_config(self): - if "train_batch_size" in self.config or "train_micro_batch_size_per_gpu" in self.config: - raise MisconfigurationException( - "Within the DeepSpeed config, do not set train_batch_size or train_micro_batch_size_per_gpu " - "as these will be passed from the data-loader." - ) + def _format_batch_size_and_grad_accum_config(self): if "gradient_accumulation_steps" in self.config: raise MisconfigurationException( - "Within the DeepSpeed config, do not set gradient_accumulation_steps " - "as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." + "Within the DeepSpeed config, do not set gradient_accumulation_steps" + " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." ) - self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size + if "train_micro_batch_size_per_gpu" not in self.config: + # train_micro_batch_size_per_gpu is used for logging purposes, use loader batch size + self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches if "gradient_clipping" not in self.config: self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4f9c5d4f5e19f..33f0f270e07bc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,6 +29,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import Result from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.plugins import Plugin from pytorch_lightning.profiler import BaseProfiler from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator @@ -127,7 +128,7 @@ def __init__( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: bool = True, - plugins: Optional[Union[str, list]] = None, + plugins: Optional[Union[Plugin, str, list]] = None, amp_backend: str = 'native', amp_level: str = 'O2', distributed_backend: Optional[str] = None, diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 1f16b316b7a50..7d44056f4f708 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -63,6 +63,36 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): + """ + Test to ensure that the plugin can be passed via a string with an environment variable. + """ + config_path = os.path.join(tmpdir, 'temp.json') + with open(config_path, 'w') as f: + f.write(json.dumps(deepspeed_config)) + monkeypatch.setenv("DEEPSPEED_CONFIG_PATH", config_path) + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + plugin = trainer.accelerator_backend.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.parallel_devices == [torch.device('cpu')] + assert plugin.config == deepspeed_config + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + plugins='deepspeed', + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") def test_deepspeed_amp_choice(tmpdir): From 978470cd11156f20f5c324c27578c73228f3f99b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 20:17:28 +0000 Subject: [PATCH 09/58] Fix docs --- docs/source/advanced/multi_gpu.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 3de32d77769a6..32008857d48c8 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -678,8 +678,8 @@ DeepSpeed [EXPERIMENTAL] .. note:: The DeepSpeed plugin is experimental and the API is subject to change. Please create an `issue `_ if you run into any issues. -`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful memory optimized optimizers such as `1-bit Adam `_. -Using the plugin, we were able to train model sizes of 10 Billion+ parameters and above, with a lot of useful information in this `issue `_ and DeepSpeed `docs `_. +`DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful efficient optimizers such as `1-bit Adam `_. +Using the plugin, we were able to **train model sizes of 10 Billion+ parameters and above**, with a lot of useful information in this `issue `_ and DeepSpeed `docs `_. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations. To use DeepSpeed, you first need to install DeepSpeed using the commands below. From 41389b98fa2f0e9fd47a29566f1b46559e3abe5a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 15 Feb 2021 23:11:01 +0000 Subject: [PATCH 10/58] Try fixing issue by moving to torch adam --- pytorch_lightning/plugins/training_type/deepspeed.py | 1 - tests/plugins/test_deepspeed_plugin.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 313904e765981..e8852ed3257bf 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -21,7 +21,6 @@ import torch from torch.nn.parallel import DistributedDataParallel -from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 7d44056f4f708..01b3dec88d7c5 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -20,6 +20,7 @@ def deepspeed_config(): return { "optimizer": { "type": "Adam", + "torch_adam": True, "params": { "lr": 3e-5, "betas": [0.998, 0.999], From b1cf9c08a1249296b55544c65ac75c0a105a7f01 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 00:01:09 +0000 Subject: [PATCH 11/58] Clean up check --- pytorch_lightning/plugins/training_type/deepspeed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index e8852ed3257bf..9377c920637be 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -126,16 +126,16 @@ def init_deepspeed(self): self._initialize_deepspeed_inference(model) def _init_scheduler_optimizer(self): - optimizer, lightning_scheduler, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( + optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers( self.lightning_module ) - if (len(optimizer) != 1) or (lightning_scheduler is not None and len(lightning_scheduler) != 1): + if (len(optimizers) != 1) or len(schedulers) > 1: raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." ) - lightning_scheduler = lightning_scheduler[0]['scheduler'] if lightning_scheduler else None - optimizer = optimizer[0] - return optimizer, lightning_scheduler, optimizer_frequencies + scheduler = schedulers[0]['scheduler'] if len(schedulers) > 1 else None + optimizer = optimizers[0] + return optimizer, scheduler, optimizer_frequencies def _initialize_deepspeed_train(self, model): optimizer, lightning_scheduler, optimizer_frequencies = None, None, None From beea3066e5f5de8877ac778dac4b2cdede79a31c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 12:05:16 +0000 Subject: [PATCH 12/58] Changes, better APIs! --- docs/source/advanced/multi_gpu.rst | 59 +++++-- .../plugins/training_type/deepspeed.py | 99 +++++++++-- tests/plugins/test_deepspeed_plugin.py | 157 +++++++++++++----- 3 files changed, 246 insertions(+), 69 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 32008857d48c8..80ddc634f8018 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -672,33 +672,65 @@ Internally we re-initialize your optimizers and shard them across your machines .. _deep_speed: -DeepSpeed [EXPERIMENTAL] -^^^^^^^^^^^^^^^^^^^^^^^^ +DeepSpeed +^^^^^^^^^ .. note:: - The DeepSpeed plugin is experimental and the API is subject to change. Please create an `issue `_ if you run into any issues. + The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue `_ if you run into any issues. `DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful efficient optimizers such as `1-bit Adam `_. -Using the plugin, we were able to **train model sizes of 10 Billion+ parameters and above**, with a lot of useful information in this `issue `_ and DeepSpeed `docs `_. +Using the plugin, we were able to **train model sizes of 10 Billion+ parameters and above**, with a lot of useful information in this `benchmark `_ and DeepSpeed `docs `_. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations. To use DeepSpeed, you first need to install DeepSpeed using the commands below. .. code-block:: bash - pip install mpi4py deepspeed + pip install deepspeed mpi4py If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. +ZeRO-Offload +"""""""""""" + Below we show an example of running `ZeRO-Offload `_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption. -For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. +For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. By default we enable ZeRO-Offload. + +.. note:: + To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. `_. + +.. code-block:: python + + from pytorch_lightning import Trainer + + model = MyModel() + trainer = Trainer(gpus=4, plugins='deepspeed', precision=16) + trainer.fit(model) + + +This can also be done via the command line using a Pytorch Lightning script: + +.. code-block:: bash -DeepSpeed requires that the optimizers and schedulers are defined within a config file. -We've included the config to enable ZeRO-Offload, as well as set the bucket sizes to reasonable defaults for low VRAM GPUs (less than 7GB). + python train.py --plugins deepspeed --precision 16 --gpus 4 .. note:: - To use ZeRO-Offload to train large models, you must use ``precision=16`` or set precision via `the DeepSpeed config. `_. + We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters based on your model size. + These control how large a buffer we limit the model to using when reducing gradients/gathering updated parameters. Smaller values will result in less memory, but tradeoff with speed. + + DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameters. + + The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around 5e8. + + +Custom DeepSpeed Config +""""""""""""""""""""""" + +DeepSpeed allows to use custom optimizers and schedulers that are defined within a config file. This allows you to enable Optimizers such as `1-bit Adam `_. + +.. note:: + All plugin default parameters will be ignored when a config object is passed. All compatible arguments can be seen in the `DeepSpeed docs `_. .. code-block:: python @@ -756,15 +788,8 @@ You can use also use an environment variable via your Pytorch Lightning script: .. code-block:: bash - DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py plugins=deepspeed - - -.. note:: - We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters based on your model size. Larger values will be more efficient in terms of throughput time, but will require more memory. - - DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameters. + DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed - As reference, a reduce buffer size of 2e8 means you'll allocate roughly 3.6GB of VRAM for the buffer. ---------- diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 9377c920637be..6aa38b332e2b4 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -17,7 +17,7 @@ import os from pathlib import Path from types import SimpleNamespace -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch.nn.parallel import DistributedDataParallel @@ -66,28 +66,95 @@ class DeepSpeedPlugin(DDPPlugin): def __init__( self, + zero_optimization: bool = True, + stage: int = 2, + cpu_offload: bool = True, + contiguous_gradients: bool = True, + overlap_comm: bool = True, + allgather_partitions: bool = True, + reduce_scatter: bool = True, + allgather_bucket_size: int = 2e8, + reduce_bucket_size: int = 2e8, + zero_allow_untested_optimizer: bool = True, config: Optional[Union[Path, str, dict]] = None, logging_level: int = logging.WARN, num_nodes: int = 1, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, ) -> None: + """ + + Provides capabilities to run training using the DeepSpeed library, + with training optimizations for large billion parameter models. + `For more information: https://www.deepspeed.ai/`. + + .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change. + + Defaults have been set to enable ZeRO-Offload and some have been taken from the link below. + These defaults have been set generally, but may require tuning for optimum performance based on your model size. + `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`. + + Arguments: + + zero_optimization: Enable ZERO optimization. This is only compatible with precision=16. (default: True) + + stage: Different stages of the ZeRO Optimizer. 0 is disabled, + 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2) + + cpu_offload: Enable offloading optimizer memory and computation to CPU (default: True) + + contiguous_gradients: Copies gradients to a continuous buffer as they are produced. + Avoids memory fragmentation during backwards. Useful when training large models.(default: True) + + overlap_comm: Overlap the reduction(synchronization) of gradients with the backwards computation. + This is a speed optimization when training across multiple GPUs/machines. (default: True) + + allgather_partitions: All gather updated parameters at the end of training step, + instead of using a series of broadcast collectives (default: True) + + reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default:True) + + allgather_bucket_size: Number of elements to allgather at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8) + + reduce_bucket_size: Number of elements to reduce at once. + Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8) + + zero_allow_untested_optimizer: Allow untested optimizers to be used with ZERO. Currently only Adam is a + supported Optimizer (default: True) + + config: Pass in a deepspeed formatted config dict, + or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. + All defaults will be ignored if a config is passed in. (Default: ``None``) + + logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) + + """ super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment ) self.config = self._load_config(config) + if self.config is None: + # User has not overridden config, set defaults + self.config = self._create_default_config( + zero_optimization, + zero_allow_untested_optimizer, + stage=stage, + cpu_offload=cpu_offload, + contiguous_gradients=contiguous_gradients, + overlap_comm=overlap_comm, + allgather_partitions=allgather_partitions, + reduce_scatter=reduce_scatter, + allgather_bucket_size=allgather_bucket_size, + reduce_bucket_size=reduce_bucket_size + ) self._config_initialized = False deepspeed.utils.logging.logger.setLevel(logging_level) def _load_config(self, config): - if config is None: - if self.DEEPSPEED_ENV_VAR not in os.environ: - raise MisconfigurationException( - "You did not pass a DeepSpeed config object or path for DeepSpeed. This can be passed" - " via instantiating the `DeepSpeedPlugin` object, or by the DEEPSPEED_CONFIG_PATH env variable." - " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" - ) - config = os.environ.get(self.DEEPSPEED_ENV_VAR) + if config is None and self.DEEPSPEED_ENV_VAR in os.environ: + rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") + config = os.environ[self.DEEPSPEED_ENV_VAR] if isinstance(config, str) or isinstance(config, Path): if os.path.exists(config): with open(config) as f: @@ -133,7 +200,7 @@ def _init_scheduler_optimizer(self): raise MisconfigurationException( "DeepSpeed currently only supports single optimizer, single optional scheduler." ) - scheduler = schedulers[0]['scheduler'] if len(schedulers) > 1 else None + scheduler = schedulers[0]['scheduler'] if len(schedulers) == 1 else None optimizer = optimizers[0] return optimizer, scheduler, optimizer_frequencies @@ -145,7 +212,6 @@ def _initialize_deepspeed_train(self, model): "Using `configure_optimizers` to define optimizer and scheduler." ) optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer() - model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) model, optimizer, _, lr_scheduler = deepspeed.initialize( args=SimpleNamespace(local_rank=self.local_rank), @@ -207,7 +273,7 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla self.model.step(**kwargs) def _format_config(self): - if not self.config: + if self.config is None: raise MisconfigurationException( "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config." " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed" @@ -241,3 +307,12 @@ def _format_precision_config(self): "enabled": True, "opt_level": amp_level, } + if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config): + raise MisconfigurationException("To use DeepSpeed ZeRO Optimization, you must set precision=16.") + + def _create_default_config( + self, zero_optimization: bool, zero_allow_untested_optimizer: bool, **zero_kwargs + ) -> Dict: + if zero_optimization: + return {"zero_allow_untested_optimizer": zero_allow_untested_optimizer, "zero_optimization": zero_kwargs} + return {} diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 01b3dec88d7c5..ddbf8f0128d76 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -40,6 +40,30 @@ def deepspeed_config(): } +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_plugin_string(tmpdir): + """ + Test to ensure that the plugin can be passed via string, and parallel devices is correctly set. + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + plugins='deepspeed', + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_plugin(tmpdir): """ @@ -56,7 +80,7 @@ def on_fit_start(self, trainer, pl_module): model = BoringModel() trainer = Trainer( fast_dev_run=True, - plugins=[DeepSpeedPlugin(config={})], + plugins=[DeepSpeedPlugin()], callbacks=[CB()], ) @@ -111,9 +135,7 @@ def on_fit_start(self, trainer, pl_module): raise SystemExit() model = BoringModel() - trainer = Trainer( - fast_dev_run=True, plugins=[DeepSpeedPlugin(config={})], callbacks=[CB()], amp_backend='native', precision=16 - ) + trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend='native', precision=16) with pytest.raises(SystemExit): trainer.fit(model) @@ -136,35 +158,12 @@ def on_fit_start(self, trainer, pl_module): raise SystemExit() model = BoringModel() - trainer = Trainer( - fast_dev_run=True, plugins=[DeepSpeedPlugin(config={})], callbacks=[CB()], amp_backend='apex', precision=16 - ) + trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend='apex', precision=16) with pytest.raises(SystemExit): trainer.fit(model) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -def test_invalid_deepspeed_without_config(tmpdir): - """ - Test to ensure if a DeepSpeed config is not provided, we throw an exception. - """ - model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - plugins=[DeepSpeedPlugin(config={})], - gpus=1, - ) - - with pytest.raises( - MisconfigurationException, - match="To use DeepSpeed you must pass in a DeepSpeed config dictionary, or path to a json config." - ): - trainer.fit(model) - - @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_with_invalid_config_path(tmpdir): """ @@ -190,21 +189,60 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): assert plugin.config == deepspeed_config +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_defaults(tmpdir): + """ + Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed. + """ + plugin = DeepSpeedPlugin() + assert plugin.config is not None + assert isinstance(plugin.config["zero_optimization"], dict) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) -def test_deepspeed(tmpdir, deepspeed_config): +def test_invalid_deepspeed_defaults_no_precision(tmpdir): """ - Test to ensure deepspeed works correctly with a valid config object, - and saves the model weights to load correctly. + Test to ensure that using defaults, if precision is not set to 16, we throw an exception. """ model = BoringModel() trainer = Trainer( - plugins=[DeepSpeedPlugin(deepspeed_config)], + fast_dev_run=True, + plugins='deepspeed', + gpus=1, + ) + with pytest.raises( + MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' + ): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_run_configure_optimizers(tmpdir): + """ + Test to end to end that deepspeed works with defaults, + whilst using configure_optimizers for optimizers and schedulers. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer + assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) + assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler + assert isinstance(self.trainer.model.optimizer, FP16_DeepSpeedZeroOptimizer) + assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR) + + model = TestModel() + trainer = Trainer( + plugins='deepspeed', gpus=1, + precision=16, fast_dev_run=True, ) @@ -213,12 +251,50 @@ def test_deepspeed(tmpdir, deepspeed_config): checkpoint_path = os.path.join(tmpdir, 'model.pt') trainer.save_checkpoint(checkpoint_path) saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + model = model.cpu().float() # Assert model parameters are identical after loading for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): assert torch.equal(orig_param, trained_model_param) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_deepspeed_config(tmpdir, deepspeed_config): + """ + Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers + and saves the model weights to load correctly. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + import deepspeed + assert isinstance(self.trainer.optimizers[0], deepspeed.ops.adam.FusedAdam) + assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally + assert isinstance(self.trainer.model.optimizer, deepspeed.ops.adam.FusedAdam) + assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR) + + model = TestModel() + trainer = Trainer( + plugins=[DeepSpeedPlugin(config=deepspeed_config)], + gpus=1, + fast_dev_run=True, + ) + + trainer.fit(model) + trainer.test(model) + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + model = model.cpu() + # Assert model parameters are identical after loading + for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): + assert torch.equal(orig_param, trained_model_param) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @@ -226,12 +302,12 @@ def test_deepspeed(tmpdir, deepspeed_config): @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_deepspeed_offload_zero_multigpu(tmpdir, deepspeed_config): +def test_deepspeed_offload_zero_multigpu_config(tmpdir, deepspeed_config): """ - Test to ensure that zero offload with multiple GPUs works correctly. + Test to ensure that zero offload with multiple GPUs works correctly if a user passes a ZeRO enabled config. """ deepspeed_config = { - **deepspeed_config, "zero_allow_untested_optimizer": False, + **deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": { "stage": 2, "cpu_offload": True, @@ -241,12 +317,13 @@ def test_deepspeed_offload_zero_multigpu(tmpdir, deepspeed_config): } model = BoringModel() trainer = Trainer( - plugins=[DeepSpeedPlugin(deepspeed_config)], + plugins=[DeepSpeedPlugin(config=deepspeed_config)], gpus=2, fast_dev_run=True, precision=16, ) trainer.fit(model) + trainer.test(model) checkpoint_path = os.path.join(tmpdir, 'model.pt') trainer.save_checkpoint(checkpoint_path) @@ -254,7 +331,7 @@ def test_deepspeed_offload_zero_multigpu(tmpdir, deepspeed_config): if trainer.global_rank == 0: saved_model = BoringModel.load_from_checkpoint(checkpoint_path) saved_model = saved_model.float() - model = model.float() + model = model.float().cpu() # Assert model parameters are identical after loading for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): assert torch.equal(orig_param, trained_model_param) From 2c659fef5439463ab4e7fe119a56ca4859be3a2d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 12:31:17 +0000 Subject: [PATCH 13/58] Add wrapper, swap to git install revision --- docs/source/advanced/multi_gpu.rst | 2 +- requirements/extra.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 80ddc634f8018..45510b762f37b 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -721,7 +721,7 @@ This can also be done via the command line using a Pytorch Lightning script: DeepSpeed allocates a reduce buffer size `multiplied by 4.5x `_ so take that into consideration when tweaking the parameters. - The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around 5e8. + The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``. Custom DeepSpeed Config diff --git a/requirements/extra.txt b/requirements/extra.txt index f37fda47ee0a0..df0fc70282898 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip -https://github.com/microsoft/DeepSpeed/archive/ec8b1cb0a0a5752bba029da4bdc91616c0f5bec7.zip # TODO: move to DeepSpeed release version +git+https://github.com/microsoft/DeepSpeed.git#commit@ec8b1cb # TODO: move to DeepSpeed release version From 4b295cd1c604e96d4bd2240bb9c9bbb35838cca3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 12:38:55 +0000 Subject: [PATCH 14/58] Add special test --- tests/special_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 3ad6e65512585..f7ed2151bf8ab 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,6 +17,7 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_offload_zero_multigpu_config python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp From caeac52334b699b7f473c6a6ed03c04ee13bbb1e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 13:48:05 +0000 Subject: [PATCH 15/58] Add warning --- .../plugins/precision/deepspeed_precision.py | 11 ++++++-- tests/plugins/test_deepspeed_plugin.py | 26 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 72ca4a526df6d..5a249cf122818 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -5,6 +5,10 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class DeepSpeedPrecisionPlugin(PrecisionPlugin): @@ -36,9 +40,12 @@ def backward( *args, **kwargs, ): + if is_overridden('backward', lightning_module): + warning_cache.warn( + "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles" + "backward logic outside of the LightningModule" + ) # todo: hack around for deepspeed engine to call backward - # Means that the lightning module backward function is never called - # This is an issue if the user overrides the backwards function deepspeed_engine = lightning_module.trainer.model deepspeed_engine.backward(closure_loss, **kwargs) # once backward has been applied, release graph diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index ddbf8f0128d76..a7b7d8473086d 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -4,6 +4,8 @@ import pytest import torch +from torch import Tensor +from torch.optim import Optimizer from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback @@ -218,6 +220,30 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): trainer.fit(model) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +def test_warn_deepspeed_override_backward(tmpdir): + """ + Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning. + """ + + class TestModel(BoringModel): + + def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + return loss.backward() + + model = TestModel() + trainer = Trainer( + fast_dev_run=True, + plugins='deepspeed', + precision=16, + gpus=1, + ) + with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'): + trainer.fit(model) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") From fcc2f993d5b16b87a6fc15938e7518f658f1bbc9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 14:42:37 +0000 Subject: [PATCH 16/58] Address review --- pytorch_lightning/plugins/training_type/deepspeed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 6aa38b332e2b4..ec79103b3ced6 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -225,7 +225,6 @@ def _initialize_deepspeed_train(self, model): # set optimizer for save/load, but deepspeed manages the specific optimizer logic trainer = self.lightning_module.trainer trainer.optimizers = [optimizer] - trainer.convert_to_lightning_optimizers() self.model = model def _initialize_deepspeed_inference(self, model): @@ -270,6 +269,8 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[L return [], [], [] # empty optimizers, schedulers and frequencies def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): + # note: We rely on the deepspeed engine to carry out the step rather than the optimizer. + # internally, the engine has a reference to the optimizer already. self.model.step(**kwargs) def _format_config(self): From 91cc1e055ac1c205b37d75f000937c3c6429f6ed Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 15:14:59 +0000 Subject: [PATCH 17/58] Add better disclaimer --- pytorch_lightning/plugins/training_type/deepspeed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index ec79103b3ced6..1f9a41db6b456 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -289,8 +289,10 @@ def _format_batch_size_and_grad_accum_config(self): " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." ) if "train_micro_batch_size_per_gpu" not in self.config: - # train_micro_batch_size_per_gpu is used for logging purposes, use loader batch size - self.config["train_micro_batch_size_per_gpu"] = self.lightning_module.train_dataloader().batch_size + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed + batch_size = self.lightning_module.train_dataloader().batch_size + self.config["train_micro_batch_size_per_gpu"] = batch_size self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches if "gradient_clipping" not in self.config: self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val From 37542d65746cc8b1ab6d133db4eb62efbdb8c831 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 16:32:59 +0000 Subject: [PATCH 18/58] Turn off ZeRO for testing due to compilation --- tests/plugins/test_deepspeed_plugin.py | 35 +++++++------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index a7b7d8473086d..0a35660efac21 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -21,13 +21,9 @@ def deepspeed_config(): return { "optimizer": { - "type": "Adam", - "torch_adam": True, + "type": "SGD", "params": { "lr": 3e-5, - "betas": [0.998, 0.999], - "eps": 1e-5, - "weight_decay": 1e-9, }, }, 'scheduler': { @@ -249,26 +245,22 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_run_configure_optimizers(tmpdir): """ - Test to end to end that deepspeed works with defaults, + Test to end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers. """ class TestModel(BoringModel): def on_train_start(self) -> None: - from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer - assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer) - assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD) + assert isinstance(self.trainer.optimizers[0], torch.optim.SGD) assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler - assert isinstance(self.trainer.model.optimizer, FP16_DeepSpeedZeroOptimizer) assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR) model = TestModel() trainer = Trainer( - plugins='deepspeed', + plugins=DeepSpeedPlugin(zero_optimization=False), gpus=1, - precision=16, fast_dev_run=True, ) @@ -297,9 +289,9 @@ class TestModel(BoringModel): def on_train_start(self) -> None: import deepspeed - assert isinstance(self.trainer.optimizers[0], deepspeed.ops.adam.FusedAdam) + assert isinstance(self.trainer.optimizers[0], torch.optim.SGD) assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally - assert isinstance(self.trainer.model.optimizer, deepspeed.ops.adam.FusedAdam) + assert isinstance(self.trainer.model.optimizer, torch.optim.SGD) assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR) model = TestModel() @@ -328,22 +320,13 @@ def on_train_start(self) -> None: @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_deepspeed_offload_zero_multigpu_config(tmpdir, deepspeed_config): +def test_deepspeed_multigpu(tmpdir, deepspeed_config): """ - Test to ensure that zero offload with multiple GPUs works correctly if a user passes a ZeRO enabled config. + Test to ensure that DeepSpeed with multiple GPUs works, without ZeRO Optimization as this requires compilation. """ - deepspeed_config = { - **deepspeed_config, "zero_allow_untested_optimizer": True, - "zero_optimization": { - "stage": 2, - "cpu_offload": True, - "contiguous_gradients": True, - "overlap_comm": True - } - } model = BoringModel() trainer = Trainer( - plugins=[DeepSpeedPlugin(config=deepspeed_config)], + plugins=[DeepSpeedPlugin(zero_optimization=False)], gpus=2, fast_dev_run=True, precision=16, From a11695df5564a1a78cd0986ebeeaa3a9932f9c23 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 16:54:11 +0000 Subject: [PATCH 19/58] Add description on modifying parameters via the plugin --- docs/source/advanced/multi_gpu.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 45510b762f37b..b7369b47075da 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -715,6 +715,19 @@ This can also be done via the command line using a Pytorch Lightning script: python train.py --plugins deepspeed --precision 16 --gpus 4 + +Modify ZeRO parameters via the plugin as below. + +.. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.plugins import DeepSpeedPlugin + + model = MyModel() + trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16) + trainer.fit(model) + + .. note:: We suggest tuning the ``allgather_bucket_size`` parameter and ``reduce_bucket_size`` parameter to find optimum parameters based on your model size. These control how large a buffer we limit the model to using when reducing gradients/gathering updated parameters. Smaller values will result in less memory, but tradeoff with speed. From f4585f0d157590206f0e9cd4ddc5e327ddf0bf02 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 17:59:49 +0000 Subject: [PATCH 20/58] Doc strings clear --- pytorch_lightning/plugins/training_type/deepspeed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 1f9a41db6b456..0a5c2007021e8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -120,8 +120,8 @@ def __init__( reduce_bucket_size: Number of elements to reduce at once. Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8) - zero_allow_untested_optimizer: Allow untested optimizers to be used with ZERO. Currently only Adam is a - supported Optimizer (default: True) + zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a + DeepSpeed supported optimizer when using ZeRO (default: True) config: Pass in a deepspeed formatted config dict, or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. From 52b654d4dd438cceb528739321bf8a7d53600211 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 19:22:32 +0000 Subject: [PATCH 21/58] Small doc fixes --- docs/source/advanced/multi_gpu.rst | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index b7369b47075da..c9b95ebd353c1 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -679,8 +679,8 @@ DeepSpeed The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue `_ if you run into any issues. `DeepSpeed `_ offers additional CUDA deep learning training optimizations, similar to `FairScale `_. DeepSpeed offers lower level training optimizations, and useful efficient optimizers such as `1-bit Adam `_. -Using the plugin, we were able to **train model sizes of 10 Billion+ parameters and above**, with a lot of useful information in this `benchmark `_ and DeepSpeed `docs `_. -We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models), and where sacrificing flexibility as a tradeoff is acceptable. In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations. +Using the plugin, we were able to **train model sizes of 10 Billion parameters and above**, with a lot of useful information in this `benchmark `_ and the DeepSpeed `docs `_. +We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models). In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations, primarily due to FairScale Sharded ease of use in scenarios such as multiple optimizers/schedulers. To use DeepSpeed, you first need to install DeepSpeed using the commands below. @@ -716,7 +716,7 @@ This can also be done via the command line using a Pytorch Lightning script: python train.py --plugins deepspeed --precision 16 --gpus 4 -Modify ZeRO parameters via the plugin as below. +You can also modify the ZeRO-Offload parameters via the plugin as below. .. code-block:: python @@ -740,7 +740,7 @@ Modify ZeRO parameters via the plugin as below. Custom DeepSpeed Config """"""""""""""""""""""" -DeepSpeed allows to use custom optimizers and schedulers that are defined within a config file. This allows you to enable Optimizers such as `1-bit Adam `_. +DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam `_. .. note:: All plugin default parameters will be ignored when a config object is passed. @@ -752,13 +752,15 @@ DeepSpeed allows to use custom optimizers and schedulers that are defined within from pytorch_lightning.plugins import DeepSpeedPlugin deepspeed_config = { + "zero_allow_untested_optimizer": True, "optimizer": { - "type": "Adam", + "type": "OneBitAdam", "params": { "lr": 3e-5, "betas": [0.998, 0.999], "eps": 1e-5, "weight_decay": 1e-9, + "cuda_aware": True, }, }, 'scheduler': { @@ -774,7 +776,7 @@ DeepSpeed allows to use custom optimizers and schedulers that are defined within "stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning) "cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU "contiguous_gradients": True, # Reduce gradient fragmentation. - "overlap_comm": True # Overlap reduce/backward operation of gradients for speed. + "overlap_comm": True, # Overlap reduce/backward operation of gradients for speed. "allgather_bucket_size": 2e8, # Number of elements to all gather at once. "reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once. } From b06bd2c28e454e5ec72ee84a7d8214902b6047fc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 21:06:39 +0000 Subject: [PATCH 22/58] Fix hash, reduce test --- requirements/extra.txt | 2 +- tests/plugins/test_deepspeed_plugin.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 64d69d914eefe..08d8f2110f527 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -git+https://github.com/microsoft/DeepSpeed.git#commit@ec8b1cb # TODO: move to DeepSpeed release version +git+https://github.com/microsoft/DeepSpeed.git@ec8b1cb # TODO: move to DeepSpeed release version diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 0a35660efac21..1228423747a77 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -232,8 +232,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args model = TestModel() trainer = Trainer( fast_dev_run=True, - plugins='deepspeed', - precision=16, + plugins=DeepSpeedPlugin(zero_optimization=False), gpus=1, ) with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'): From 83535fbe3a3afdaf63d46031882c0bbacf6b70e6 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 21:37:30 +0000 Subject: [PATCH 23/58] Added CI change --- .github/workflows/ci_test-full.yml | 7 +++++++ requirements/extra.txt | 1 - 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 300a0748dcda3..3bced7a8c8e91 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -112,6 +112,13 @@ jobs: pip list shell: bash + - name: Install DeepSpeed + # todo: This is a temporary fix to install DeepSpeed outside of the extras.txt package for CI. + if: runner.os != 'windows' + run: | + pip install git+https://github.com/microsoft/DeepSpeed.git@ec8b1cb + shell: bash + - name: Reinstall Horovod if necessary if: runner.os != 'windows' env: diff --git a/requirements/extra.txt b/requirements/extra.txt index 08d8f2110f527..0e7dffbcb39b0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,3 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -git+https://github.com/microsoft/DeepSpeed.git@ec8b1cb # TODO: move to DeepSpeed release version From a1e487d0659ebc6eca59fc8abece86ad092984b6 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:05:12 +0000 Subject: [PATCH 24/58] Move to azure pipeline --- .github/workflows/ci_test-full.yml | 7 ------- azure-pipelines.yml | 6 ++++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index 3bced7a8c8e91..300a0748dcda3 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -112,13 +112,6 @@ jobs: pip list shell: bash - - name: Install DeepSpeed - # todo: This is a temporary fix to install DeepSpeed outside of the extras.txt package for CI. - if: runner.os != 'windows' - run: | - pip install git+https://github.com/microsoft/DeepSpeed.git@ec8b1cb - shell: bash - - name: Reinstall Horovod if necessary if: runner.os != 'windows' env: diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8a6f1324521b0..c7db435873a82 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -61,6 +61,12 @@ jobs: pip list displayName: 'Install dependencies' + - bash: | + # temporary fix till deepspeed next stable release + sudo apt install libopenmpi-dev + pip install deepspeed mpi4py + displayName: 'Install DeepSpeed' + - script: | python tests/collect_env_details.py displayName: 'Env details' From 535800c46799d57237893825fe56ccbc003d455b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:05:46 +0000 Subject: [PATCH 25/58] Fix test name --- tests/special_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index e7b12f3a5319a..ffb21255a6d3c 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,7 +17,7 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_offload_zero_multigpu_config +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp From 471ccdfd4e7f3db7a9d2fa002f607b83137fbfb9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:12:45 +0000 Subject: [PATCH 26/58] Add missing flag --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c7db435873a82..f81248e77913b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,7 +63,7 @@ jobs: - bash: | # temporary fix till deepspeed next stable release - sudo apt install libopenmpi-dev + sudo apt install libopenmpi-dev -y pip install deepspeed mpi4py displayName: 'Install DeepSpeed' From e458c192f8d3232edab305482663629fca0cd114 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:17:03 +0000 Subject: [PATCH 27/58] Remove sudo... --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index f81248e77913b..51d3f76e25463 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,7 +63,7 @@ jobs: - bash: | # temporary fix till deepspeed next stable release - sudo apt install libopenmpi-dev -y + apt install libopenmpi-dev -y pip install deepspeed mpi4py displayName: 'Install DeepSpeed' From 9826ca893d74b33fb1c9b7a58c36608b7ba16414 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:27:45 +0000 Subject: [PATCH 28/58] Try conda instead --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 51d3f76e25463..287f6941565e0 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,7 +63,7 @@ jobs: - bash: | # temporary fix till deepspeed next stable release - apt install libopenmpi-dev -y + conda install mpi4py -y pip install deepspeed mpi4py displayName: 'Install DeepSpeed' From 45ea29002a1c0bd76797a403eba630265cb4e972 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 22:41:16 +0000 Subject: [PATCH 29/58] Swap to conda base --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 287f6941565e0..e36cb0c41c359 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -32,7 +32,7 @@ jobs: #container: "pytorchlightning/pytorch_lightning:base-cuda-py$[ variables['python.version'] ]-torch1.6" container: # base ML image: mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04 - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6" + image: "pytorchlightning/pytorch_lightning:base-conda-py3.7-torch1.6" #endpoint: azureContainerRegistryConnection options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" From 9272a95f8a823b7b601455832e01f27a97d25d70 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 16 Feb 2021 23:09:43 +0000 Subject: [PATCH 30/58] Try suggested install --- azure-pipelines.yml | 8 +------- requirements/extra.txt | 1 + 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e36cb0c41c359..8a6f1324521b0 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -32,7 +32,7 @@ jobs: #container: "pytorchlightning/pytorch_lightning:base-cuda-py$[ variables['python.version'] ]-torch1.6" container: # base ML image: mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04 - image: "pytorchlightning/pytorch_lightning:base-conda-py3.7-torch1.6" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6" #endpoint: azureContainerRegistryConnection options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" @@ -61,12 +61,6 @@ jobs: pip list displayName: 'Install dependencies' - - bash: | - # temporary fix till deepspeed next stable release - conda install mpi4py -y - pip install deepspeed mpi4py - displayName: 'Install DeepSpeed' - - script: | python tests/collect_env_details.py displayName: 'Env details' diff --git a/requirements/extra.txt b/requirements/extra.txt index 0e7dffbcb39b0..f8c0846ebaf84 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,3 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip +deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb From e06ec29ec8f8693e07df9be4763829a9da34bbce Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Feb 2021 00:27:14 +0100 Subject: [PATCH 31/58] Apply suggestions from code review --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index f8c0846ebaf84..5413e9bb07b0e 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb +deepspeed@https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip From 41cca05ae9f5c4210b4069f1df6b2118f3ff938a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Feb 2021 00:35:27 +0100 Subject: [PATCH 32/58] Apply suggestions from code review --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 5413e9bb07b0e..61fdc87aba032 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -deepspeed@https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip +https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip From 37f2d9d35907a4baf59ce4c31103eec31126854b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 00:04:15 +0000 Subject: [PATCH 33/58] Revert "Apply suggestions from code review" This reverts commit 41cca05a --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 61fdc87aba032..5413e9bb07b0e 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip +deepspeed@https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip From c0155307ceaed3d6da40581dfafda082109b683d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 00:04:23 +0000 Subject: [PATCH 34/58] Revert "Apply suggestions from code review" This reverts commit e06ec29e --- requirements/extra.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 5413e9bb07b0e..f8c0846ebaf84 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -deepspeed@https://github.com/microsoft/DeepSpeed/archive/ec8b1cb.zip +deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb From 054b320afa82a0d5304cd29abe1fde2d9f8d7efa Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 00:35:16 +0000 Subject: [PATCH 35/58] Remove setter --- pytorch_lightning/plugins/training_type/deepspeed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 0a5c2007021e8..3ac15700d5832 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -42,7 +42,6 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__(self, pl_module: LightningModule, precision: int): super().__init__(pl_module) - self.module = pl_module self.precision = precision def forward(self, *inputs, **kwargs): From 301c32d0b609ac5b60fcf2e427b9af61e442c764 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 10:46:01 +0000 Subject: [PATCH 36/58] Address most review --- docs/source/advanced/multi_gpu.rst | 2 +- .../plugins/training_type/deepspeed.py | 17 +++--- pytorch_lightning/utilities/imports.py | 2 +- requirements/extra.txt | 2 +- tests/plugins/test_deepspeed_plugin.py | 53 +++++-------------- 5 files changed, 26 insertions(+), 50 deletions(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 37baa5175eb91..cd3e9c72b3cb3 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -803,7 +803,7 @@ You can use also use an environment variable via your Pytorch Lightning script: .. code-block:: bash - DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed + PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed ---------- diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 3ac15700d5832..354ef5944ef42 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -34,8 +34,6 @@ if _DEEPSPEED_AVAILABLE: import deepspeed -else: - deepspeed = None class LightningDeepSpeedModule(_LightningModuleWrapperBase): @@ -61,7 +59,7 @@ def _move_float_tensors_to_half(self, batch: Any): class DeepSpeedPlugin(DDPPlugin): distributed_backend = "deepspeed" - DEEPSPEED_ENV_VAR = "DEEPSPEED_CONFIG_PATH" + DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH" def __init__( self, @@ -95,7 +93,7 @@ def __init__( Arguments: - zero_optimization: Enable ZERO optimization. This is only compatible with precision=16. (default: True) + zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True) stage: Different stages of the ZeRO Optimizer. 0 is disabled, 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2) @@ -103,9 +101,9 @@ def __init__( cpu_offload: Enable offloading optimizer memory and computation to CPU (default: True) contiguous_gradients: Copies gradients to a continuous buffer as they are produced. - Avoids memory fragmentation during backwards. Useful when training large models.(default: True) + Avoids memory fragmentation during backwards. Useful when training large models. (default: True) - overlap_comm: Overlap the reduction(synchronization) of gradients with the backwards computation. + overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation. This is a speed optimization when training across multiple GPUs/machines. (default: True) allgather_partitions: All gather updated parameters at the end of training step, @@ -129,6 +127,11 @@ def __init__( logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) """ + if not _DEEPSPEED_AVAILABLE: + raise MisconfigurationException( + "To use the DeepSpeed plugin, you must have DeepSpeed installed." + " pip install deepspeed mpi4py" + ) super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment ) @@ -164,7 +167,7 @@ def _load_config(self, config): ) return config - def pre_training(self): + def pre_dispatch(self): self.set_world_ranks() self.init_ddp_connection(self.global_rank, self.world_size) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 4e98f16b3c997..b4c30097fad4e 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -55,7 +55,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized') _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') -_DEEPSPEED_AVAILABLE = platform.system() != 'Windows' and _module_available('deepspeed') +_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') diff --git a/requirements/extra.txt b/requirements/extra.txt index f8c0846ebaf84..73cf16fb61731 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,4 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb +deepspeed @ git+https://github.com/microsoft/DeepSpeed@ec8b1cb diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 1228423747a77..e3ac3a29bd97e 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -1,6 +1,5 @@ import json import os -import platform import pytest import torch @@ -14,8 +13,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel -PRETEND_N_OF_GPUS = 1 - @pytest.fixture def deepspeed_config(): @@ -94,7 +91,7 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): config_path = os.path.join(tmpdir, 'temp.json') with open(config_path, 'w') as f: f.write(json.dumps(deepspeed_config)) - monkeypatch.setenv("DEEPSPEED_CONFIG_PATH", config_path) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) class CB(Callback): @@ -116,35 +113,18 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) +@pytest.mark.parametrize( + "amp_backend", [ + pytest.param("native", marks=pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")), + pytest.param("apex", marks=pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex")), + ] +) @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") -def test_deepspeed_amp_choice(tmpdir): - """ - Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via - Custom DeepSpeedPrecisionPlugin - """ - - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.accelerator_backend.precision_plugin.precision == 16 - raise SystemExit() - - model = BoringModel() - trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend='native', precision=16) - - with pytest.raises(SystemExit): - trainer.fit(model) - - -@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex") -def test_deepspeed_apex_choice(tmpdir): +def test_deepspeed_precision_choice(amp_backend, tmpdir): """ - Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via - Custom DeepSpeedPrecisionPlugin + Test to ensure precision plugin is also correctly chosen. + DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin """ class CB(Callback): @@ -156,7 +136,7 @@ def on_fit_start(self, trainer, pl_module): raise SystemExit() model = BoringModel() - trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend='apex', precision=16) + trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend=amp_backend, precision=16) with pytest.raises(SystemExit): trainer.fit(model) @@ -182,7 +162,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): config_path = os.path.join(tmpdir, 'temp.json') with open(config_path, 'w') as f: f.write(json.dumps(deepspeed_config)) - monkeypatch.setenv("DEEPSPEED_CONFIG_PATH", config_path) + monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) plugin = DeepSpeedPlugin() assert plugin.config == deepspeed_config @@ -197,8 +177,6 @@ def test_deepspeed_defaults(tmpdir): assert isinstance(plugin.config["zero_optimization"], dict) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_invalid_deepspeed_defaults_no_precision(tmpdir): """ @@ -208,7 +186,6 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): trainer = Trainer( fast_dev_run=True, plugins='deepspeed', - gpus=1, ) with pytest.raises( MisconfigurationException, match='To use DeepSpeed ZeRO Optimization, you must set precision=16.' @@ -217,7 +194,6 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_warn_deepspeed_override_backward(tmpdir): """ @@ -240,11 +216,10 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_run_configure_optimizers(tmpdir): """ - Test to end to end that deepspeed works with defaults (without ZeRO as that requires compilation), + Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers. """ @@ -276,7 +251,6 @@ def on_train_start(self) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_config(tmpdir, deepspeed_config): """ @@ -313,7 +287,6 @@ def on_train_start(self) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif( From 3fda074b1d3542b32ca67590c10537b3f7545379 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 10:57:23 +0000 Subject: [PATCH 37/58] Move out function, remove DeepSpeed from requirements --- requirements/extra.txt | 1 - tests/plugins/test_deepspeed_plugin.py | 21 ++++++--------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index 73cf16fb61731..0e7dffbcb39b0 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -8,4 +8,3 @@ onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip -deepspeed @ git+https://github.com/microsoft/DeepSpeed@ec8b1cb diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index e3ac3a29bd97e..2f55ee4d1f9ba 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -240,14 +240,7 @@ def on_train_start(self) -> None: trainer.fit(model) - checkpoint_path = os.path.join(tmpdir, 'model.pt') - trainer.save_checkpoint(checkpoint_path) - saved_model = BoringModel.load_from_checkpoint(checkpoint_path) - model = model.cpu().float() - - # Assert model parameters are identical after loading - for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(orig_param, trained_model_param) + _assert_save_model_is_equal(model, tmpdir, trainer) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -277,13 +270,7 @@ def on_train_start(self) -> None: trainer.fit(model) trainer.test(model) - checkpoint_path = os.path.join(tmpdir, 'model.pt') - trainer.save_checkpoint(checkpoint_path) - saved_model = BoringModel.load_from_checkpoint(checkpoint_path) - model = model.cpu() - # Assert model parameters are identical after loading - for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()): - assert torch.equal(orig_param, trained_model_param) + _assert_save_model_is_equal(model, tmpdir, trainer) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @@ -306,6 +293,10 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config): trainer.fit(model) trainer.test(model) + _assert_save_model_is_equal(model, tmpdir, trainer) + + +def _assert_save_model_is_equal(model, tmpdir, trainer): checkpoint_path = os.path.join(tmpdir, 'model.pt') trainer.save_checkpoint(checkpoint_path) # carry out the check only on rank 0 From d969d28b30b98566c3722695792c2a2a99262c76 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 12:30:52 +0000 Subject: [PATCH 38/58] Install deepspeed/mpi4py within container --- azure-pipelines.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8a6f1324521b0..a7901c1bbc78d 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -61,6 +61,11 @@ jobs: pip list displayName: 'Install dependencies' + - bash: | + # Temporary fix till DeepSpeed release + pip install deepspeed mpi4py + displayName: 'Install DeepSpeed' + - script: | python tests/collect_env_details.py displayName: 'Env details' From 5d993ecc37af6ff5f199c406e998363769f6d9c0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 14:22:26 +0000 Subject: [PATCH 39/58] Use special tests, move to master commit for deepspeed --- azure-pipelines.yml | 2 +- tests/plugins/test_deepspeed_plugin.py | 9 +++++++++ tests/special_tests.sh | 3 +++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a7901c1bbc78d..a4b8e22f17af9 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,7 +63,7 @@ jobs: - bash: | # Temporary fix till DeepSpeed release - pip install deepspeed mpi4py + pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb displayName: 'Install DeepSpeed' - script: | diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 2f55ee4d1f9ba..9f6f61ca34c07 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -195,6 +195,9 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) def test_warn_deepspeed_override_backward(tmpdir): """ Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning. @@ -217,6 +220,9 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) def test_deepspeed_run_configure_optimizers(tmpdir): """ Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), @@ -245,6 +251,9 @@ def on_train_start(self) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) def test_deepspeed_config(tmpdir, deepspeed_config): """ Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ffb21255a6d3c..472f7afda5e9e 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,6 +17,9 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers +python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual From 62f3048665e34a8179b388bb6c516d70825205ef Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 14:53:24 +0000 Subject: [PATCH 40/58] Export path --- azure-pipelines.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a4b8e22f17af9..96d0862ad7274 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -64,6 +64,8 @@ jobs: - bash: | # Temporary fix till DeepSpeed release pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb + # Update path to find ninja installation + export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin displayName: 'Install DeepSpeed' - script: | From ec7909670d2d26ef871454af13f5af93966e43ca Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 15:55:33 +0000 Subject: [PATCH 41/58] Force compile to happen first --- azure-pipelines.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 96d0862ad7274..744a14ec1e8c7 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,9 +63,7 @@ jobs: - bash: | # Temporary fix till DeepSpeed release - pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb - # Update path to find ninja installation - export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin + DS_BUILD_OPS=1 pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb displayName: 'Install DeepSpeed' - script: | From 894a6dd5db01f28cb05085071387c655077394c8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 15:56:45 +0000 Subject: [PATCH 42/58] Remove! --- azure-pipelines.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 744a14ec1e8c7..87661b1696421 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -63,7 +63,10 @@ jobs: - bash: | # Temporary fix till DeepSpeed release - DS_BUILD_OPS=1 pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb + pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb + # Update path to find ninja installation + export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin + echo $PATH displayName: 'Install DeepSpeed' - script: | From e735358fa67357f58d6b354369ac5c596a754ea3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 16:21:26 +0000 Subject: [PATCH 43/58] Debugging ninja --- azure-pipelines.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 87661b1696421..70e260d717945 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -65,8 +65,11 @@ jobs: # Temporary fix till DeepSpeed release pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb # Update path to find ninja installation + ln -s /usr/bin/ninja /home/AzDevOps_azpcontainer/.local/bin/ninja export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin echo $PATH + whereis ninja + which ninja displayName: 'Install DeepSpeed' - script: | From 60063f294f347a085384767a61aa023f0ac9b00d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 16:31:15 +0000 Subject: [PATCH 44/58] Fix error in optimizer step logic --- pytorch_lightning/plugins/precision/deepspeed_precision.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 5a249cf122818..711ede2f7ded4 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -24,9 +24,10 @@ def pre_optimizer_step( # DeepSpeed not support closures. lambda_closure() - if pl_module.automatic_optimization: + if not pl_module.automatic_optimization: pl_module.trainer.call_hook("on_after_backward") - deepspeed_engine.step() + + deepspeed_engine.step() return False From 5aa9acc3fda913ada231ae4c0c7b34f037db08e7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 16:32:44 +0000 Subject: [PATCH 45/58] Attempt to fix symbolic link --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 70e260d717945..6eb2af90814c0 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -65,7 +65,7 @@ jobs: # Temporary fix till DeepSpeed release pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb # Update path to find ninja installation - ln -s /usr/bin/ninja /home/AzDevOps_azpcontainer/.local/bin/ninja + ln -s /home/AzDevOps_azpcontainer/.local/bin/ninja /usr/bin/ninja export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin echo $PATH whereis ninja From b68a5396beccdf5c1ff8d2c100a3b91f2d960c0b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 16:52:37 +0000 Subject: [PATCH 46/58] Reverse to aid debugging --- azure-pipelines.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6eb2af90814c0..3ab013ee2020a 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -65,7 +65,6 @@ jobs: # Temporary fix till DeepSpeed release pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb # Update path to find ninja installation - ln -s /home/AzDevOps_azpcontainer/.local/bin/ninja /usr/bin/ninja export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin echo $PATH whereis ninja @@ -82,14 +81,15 @@ jobs: ls -l legacy/checkpoints/ displayName: 'Get legacy checkpoints' + - bash: | + which ninja + sh tests/special_tests.sh + displayName: 'Testing: special' + - script: | python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 displayName: 'Testing: standard' - - script: | - sh tests/special_tests.sh - displayName: 'Testing: special' - - bash: | python -m coverage report python -m coverage xml From 087892780b9f6af12ffeb0488863804555b49834 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 16:55:44 +0000 Subject: [PATCH 47/58] Export path again --- azure-pipelines.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 3ab013ee2020a..722f27657728a 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -82,6 +82,7 @@ jobs: displayName: 'Get legacy checkpoints' - bash: | + export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin which ninja sh tests/special_tests.sh displayName: 'Testing: special' From 7dd17d3d4e9494d10fd51865d9dad0528b892c07 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 17:02:22 +0000 Subject: [PATCH 48/58] Clean up mess --- azure-pipelines.yml | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 722f27657728a..cf2fcf643c0ec 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -64,11 +64,6 @@ jobs: - bash: | # Temporary fix till DeepSpeed release pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb - # Update path to find ninja installation - export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin - echo $PATH - whereis ninja - which ninja displayName: 'Install DeepSpeed' - script: | @@ -81,16 +76,16 @@ jobs: ls -l legacy/checkpoints/ displayName: 'Get legacy checkpoints' + - script: | + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 + displayName: 'Testing: standard' + - bash: | + # Required for Ninja binary for building extensions, which is installed at this location export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin - which ninja sh tests/special_tests.sh displayName: 'Testing: special' - - script: | - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 - displayName: 'Testing: standard' - - bash: | python -m coverage report python -m coverage xml From 3450eaca1682e35bd912c2824b5c88a7e6f0b0ca Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 17 Feb 2021 19:05:44 +0100 Subject: [PATCH 49/58] var --- azure-pipelines.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cf2fcf643c0ec..19eaa8232b24e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,6 +19,9 @@ jobs: # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: 2 + variables: + PATH: "$PATH:/home/AzDevOps_azpcontainer/:local/bin:" + pool: dsvm-spot-pool #strategy: @@ -76,16 +79,15 @@ jobs: ls -l legacy/checkpoints/ displayName: 'Get legacy checkpoints' - - script: | - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 - displayName: 'Testing: standard' - - bash: | - # Required for Ninja binary for building extensions, which is installed at this location - export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin + # FIXME: move it after standard sh tests/special_tests.sh displayName: 'Testing: special' + - script: | + python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 + displayName: 'Testing: standard' + - bash: | python -m coverage report python -m coverage xml From 0b9e7d54ba7295e803f7fff74a7752a68d388f3b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 18:15:04 +0000 Subject: [PATCH 50/58] Revert "var" This reverts commit 3450eaca --- azure-pipelines.yml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 19eaa8232b24e..cf2fcf643c0ec 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,9 +19,6 @@ jobs: # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: 2 - variables: - PATH: "$PATH:/home/AzDevOps_azpcontainer/:local/bin:" - pool: dsvm-spot-pool #strategy: @@ -79,15 +76,16 @@ jobs: ls -l legacy/checkpoints/ displayName: 'Get legacy checkpoints' - - bash: | - # FIXME: move it after standard - sh tests/special_tests.sh - displayName: 'Testing: special' - - script: | python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 displayName: 'Testing: standard' + - bash: | + # Required for Ninja binary for building extensions, which is installed at this location + export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin + sh tests/special_tests.sh + displayName: 'Testing: special' + - bash: | python -m coverage report python -m coverage xml From e0a2d6bb085c8dc79ebec7b409dc97d97e9d0a29 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 18:18:26 +0000 Subject: [PATCH 51/58] Address review, add todo --- azure-pipelines.yml | 2 +- pytorch_lightning/accelerators/accelerator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cf2fcf643c0ec..56c4b7120b727 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -62,7 +62,7 @@ jobs: displayName: 'Install dependencies' - bash: | - # Temporary fix till DeepSpeed release + # Temporary fix till DeepSpeed release, move this into CUDA image pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb displayName: 'Install DeepSpeed' diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 36bfba47a34b8..40b39332cda31 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -314,7 +314,7 @@ def setup_optimizers(self, trainer: "Trainer"): trainer: the Trainer, these optimizers should be connected to model: the model to be optimized by the created optimizers """ - if trainer.testing is True: + if trainer.testing: return optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers( trainer=trainer, model=self.lightning_module From c565bb11ffad4590077b960e458a546472dda9df Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 18:36:21 +0000 Subject: [PATCH 52/58] Add note about unsupported functionality --- docs/source/advanced/multi_gpu.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index cd3e9c72b3cb3..fe094189b5aa9 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -691,6 +691,9 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below. If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``). Additionally if you run into any issues installing m4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``. +.. note:: + Currently ``resume_from_checkpoint`` and manual optimization are not supported. + ZeRO-Offload """""""""""" From fbaf86f32b13b72f1e2f210c9c234b8800b475a6 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 17 Feb 2021 18:59:29 +0000 Subject: [PATCH 53/58] Update docs/source/advanced/multi_gpu.rst MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- docs/source/advanced/multi_gpu.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 5c87be51fda06..5f766268696a8 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -816,7 +816,7 @@ We support taking the config as a json formatted file: trainer.fit(model) -You can use also use an environment variable via your Pytorch Lightning script: +You can use also use an environment variable via your PyTorch Lightning script: .. code-block:: bash From 1e7dcd62f62c1acb7327bc5514a65d4d72267608 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 19:04:41 +0000 Subject: [PATCH 54/58] Address review --- tests/plugins/test_deepspeed_plugin.py | 59 ++++++-------------------- 1 file changed, 12 insertions(+), 47 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 9f6f61ca34c07..5a32293058240 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -41,22 +41,13 @@ def test_deepspeed_plugin_string(tmpdir): Test to ensure that the plugin can be passed via string, and parallel devices is correctly set. """ - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) - assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, plugins='deepspeed', - callbacks=[CB()], ) - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @@ -65,22 +56,13 @@ def test_deepspeed_plugin(tmpdir): Test to ensure that the plugin can be passed directly, and parallel devices is correctly set. """ - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) - assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, plugins=[DeepSpeedPlugin()], - callbacks=[CB()], ) - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert trainer.accelerator_backend.training_type_plugin.parallel_devices == [torch.device('cpu')] @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @@ -93,24 +75,15 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): f.write(json.dumps(deepspeed_config)) monkeypatch.setenv("PL_DEEPSPEED_CONFIG_PATH", config_path) - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - plugin = trainer.accelerator_backend.training_type_plugin - assert isinstance(plugin, DeepSpeedPlugin) - assert plugin.parallel_devices == [torch.device('cpu')] - assert plugin.config == deepspeed_config - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, plugins='deepspeed', - callbacks=[CB()], ) - with pytest.raises(SystemExit): - trainer.fit(model) + plugin = trainer.accelerator_backend.training_type_plugin + assert isinstance(plugin, DeepSpeedPlugin) + assert plugin.parallel_devices == [torch.device('cpu')] + assert plugin.config == deepspeed_config @pytest.mark.parametrize( @@ -127,19 +100,11 @@ def test_deepspeed_precision_choice(amp_backend, tmpdir): DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin """ - class CB(Callback): + trainer = Trainer(fast_dev_run=True, plugins='deepspeed', amp_backend=amp_backend, precision=16) - def on_fit_start(self, trainer, pl_module): - assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.accelerator_backend.precision_plugin.precision == 16 - raise SystemExit() - - model = BoringModel() - trainer = Trainer(fast_dev_run=True, plugins='deepspeed', callbacks=[CB()], amp_backend=amp_backend, precision=16) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) + assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.accelerator_backend.precision_plugin.precision == 16 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") From ea1d78c72d2e1599457bbb466daa85f6c263a11a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 19:05:13 +0000 Subject: [PATCH 55/58] Remove import --- tests/plugins/test_deepspeed_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 5a32293058240..55bf990f0883e 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -7,7 +7,6 @@ from torch.optim import Optimizer from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException From 10da87e3043c5b23d201a4205810ae5ce9ee6b24 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 19:08:16 +0000 Subject: [PATCH 56/58] Add tmpdir --- tests/plugins/test_deepspeed_plugin.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 55bf990f0883e..1d25c529dd963 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -42,6 +42,7 @@ def test_deepspeed_plugin_string(tmpdir): trainer = Trainer( fast_dev_run=True, + default_root_dir=tmpdir, plugins='deepspeed', ) @@ -57,6 +58,7 @@ def test_deepspeed_plugin(tmpdir): trainer = Trainer( fast_dev_run=True, + default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], ) @@ -76,6 +78,7 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): trainer = Trainer( fast_dev_run=True, + default_root_dir=tmpdir, plugins='deepspeed', ) @@ -99,7 +102,9 @@ def test_deepspeed_precision_choice(amp_backend, tmpdir): DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin """ - trainer = Trainer(fast_dev_run=True, plugins='deepspeed', amp_backend=amp_backend, precision=16) + trainer = Trainer( + fast_dev_run=True, default_root_dir=tmpdir, plugins='deepspeed', amp_backend=amp_backend, precision=16 + ) assert isinstance(trainer.accelerator_backend.training_type_plugin, DeepSpeedPlugin) assert isinstance(trainer.accelerator_backend.precision_plugin, DeepSpeedPrecisionPlugin) @@ -149,6 +154,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): model = BoringModel() trainer = Trainer( fast_dev_run=True, + default_root_dir=tmpdir, plugins='deepspeed', ) with pytest.raises( @@ -175,6 +181,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args model = TestModel() trainer = Trainer( fast_dev_run=True, + default_root_dir=tmpdir, plugins=DeepSpeedPlugin(zero_optimization=False), gpus=1, ) @@ -204,6 +211,7 @@ def on_train_start(self) -> None: model = TestModel() trainer = Trainer( plugins=DeepSpeedPlugin(zero_optimization=False), + default_root_dir=tmpdir, gpus=1, fast_dev_run=True, ) @@ -236,6 +244,7 @@ def on_train_start(self) -> None: model = TestModel() trainer = Trainer( plugins=[DeepSpeedPlugin(config=deepspeed_config)], + default_root_dir=tmpdir, gpus=1, fast_dev_run=True, ) @@ -259,6 +268,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config): model = BoringModel() trainer = Trainer( plugins=[DeepSpeedPlugin(zero_optimization=False)], + default_root_dir=tmpdir, gpus=2, fast_dev_run=True, precision=16, From f789b77a7641203db0936cce43be239c34aa8c7a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 19:18:57 +0000 Subject: [PATCH 57/58] Add note --- docs/source/advanced/multi_gpu.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 5f766268696a8..04381f42e0159 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -707,6 +707,7 @@ Additionally if you run into any issues installing m4py, ensure you have openmpi .. note:: Currently ``resume_from_checkpoint`` and manual optimization are not supported. + DeepSpeed only supports single optimizer, single scheduler. ZeRO-Offload """""""""""" From 31d6267547cd4306d5c3fa4bed74a930d7250c0c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 17 Feb 2021 19:19:04 +0000 Subject: [PATCH 58/58] Add note --- docs/source/advanced/multi_gpu.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 04381f42e0159..bb54f97706687 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -707,6 +707,7 @@ Additionally if you run into any issues installing m4py, ensure you have openmpi .. note:: Currently ``resume_from_checkpoint`` and manual optimization are not supported. + DeepSpeed only supports single optimizer, single scheduler. ZeRO-Offload