From 0f4b790d3637eb15717d9524d954dffa2737ad20 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Mon, 18 Oct 2021 14:57:20 +0100
Subject: [PATCH] Hack to allow deepspeed to run fp16

---
 pytorch_lightning/lite/lite.py                 | 36 +++++++++++++------
 .../plugins/training_type/deepspeed.py         | 15 ++++---
 2 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 76d5a6607c8e2..6ede74f817157 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -12,26 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from collections import Callable
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Optional, Sequence, Union, List, Dict, Tuple, Generator
+from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler, RandomSampler, Sampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import Accelerator, TPUAccelerator
-from pytorch_lightning.lite.wrappers import _LiteOptimizer, _LiteModule, _LiteDataLoader
-from pytorch_lightning.plugins import PLUGIN_INPUT, DDPSpawnPlugin, TrainingTypePlugin, DeepSpeedPlugin
+from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TrainingTypePlugin
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.utilities import move_data_to_device, DistributedType, DeviceType
+from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device
 from pytorch_lightning.utilities.data import has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -102,7 +102,10 @@ def __init__(
 
     @property
     def device(self) -> torch.device:
-        """The current device this process runs on. Use this to create tensors directly on the device if needed."""
+        """The current device this process runs on.
+
+        Use this to create tensors directly on the device if needed.
+        """
         return self._accelerator.root_device
 
     @property
@@ -233,8 +236,8 @@ def backward(self, tensor: Tensor, *args: Any, **kwargs: Any) -> None:
     def cast(self) -> Generator[None, None, None]:
         """A context manager to automatically convert operations for the chosen precision.
 
-        Use this only if the `forward` method of your model does not cover all operations you wish to run with
-        the chosen precision setting.
+        Use this only if the `forward` method of your model does not cover all operations you wish to run with the
+        chosen precision setting.
         """
         with self._accelerator.forward_context():
             yield
@@ -255,8 +258,10 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens
         return move_data_to_device(obj, device=self.device)
 
     def print(self, *args: Any, **kwargs: Any) -> None:
-        """Print something only on the first process. Arguments passed to this method are forwarded to the
-            Python built-in :func:`print` function."""
+        """Print something only on the first process.
+
+        Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
+        """
         if self.local_rank == 0:
             print(*args, **kwargs)
 
@@ -291,6 +296,9 @@ def _run_wrapper(self, run_method: Callable) -> Callable:
         return partial(self._run_impl, run_method)
 
     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> None:
+        if isinstance(self._strategy, DeepSpeedPlugin):
+            # todo: this is a hack as deepspeed currently relies on the precision plugin
+            self._set_deepspeed_precision_variables()
         self._strategy.setup_environment()
         if isinstance(self._strategy, DDPSpawnPlugin):
             self._strategy.spawn(run_method, *args, **kwargs)
@@ -298,6 +306,12 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> None:
             run_method(*args, **kwargs)
         # TODO: any teardown needed here?
 
+    def _set_deepspeed_precision_variables(self):
+        amp_type = self._accelerator_connector.amp_type
+        amp_level = self._accelerator_connector.amp_level
+        precision = self._accelerator_connector.precision
+        self._strategy.amp_level, self._strategy.amp_type, self._strategy._precision = amp_level, amp_type, precision
+
     def _setup_model_and_optimizers(
         self,
         model: nn.Module,
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 07ea169f7d5ff..ca2e279ddbba3 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -325,6 +325,10 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
+        self._precision = None
+        self.amp_level = None
+        self.amp_type = None
+
     def _load_config(self, config):
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
@@ -516,7 +520,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
 
     @property
     def precision(self) -> Union[str, int]:
-        return self.lightning_module.trainer.precision
+        return self._precision or self.lightning_module.trainer.precision
 
     def _set_deepspeed_activation_checkpointing(self):
         if self.config.get("activation_checkpointing"):
@@ -633,11 +637,10 @@ def _auto_select_batch_size(self):
         return batch_size
 
     def _format_precision_config(self):
-        # TODO: support precision
-        return
-        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
-        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
-        precision = self.lightning_module.trainer.accelerator_connector.precision
+        amp_type = self.amp_type or self.lightning_module.trainer.accelerator_connector.amp_type
+        precision = self.precision or self.lightning_module.trainer.accelerator_connector.precision
+        if amp_type == AMPType.APEX:
+            amp_level = self.amp_level or self.lightning_module.trainer.accelerator_connector.amp_level
         if precision in (16, "mixed"):
             if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
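
For reference, a minimal sketch of how this change is meant to be exercised end to end: running
LightningLite with the DeepSpeed strategy at 16-bit precision, so that the amp_type/amp_level/precision
values forwarded by `_set_deepspeed_precision_variables` can be picked up by `_format_precision_config`
when it builds the fp16 section of the DeepSpeed config. The sketch assumes the public LightningLite
user API as it later shipped in PyTorch Lightning 1.5 (`setup`, `setup_dataloaders`, `backward`, and the
`strategy`/`precision`/`gpus` constructor arguments) and an environment with `deepspeed` installed; the
`BoringLite` class, shapes, and hyperparameters below are illustrative only and not part of this patch.

    # Illustrative sketch, not part of the patch. Assumes the LightningLite API
    # as released in PyTorch Lightning 1.5; argument names may differ on this branch.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning.lite import LightningLite


    class BoringLite(LightningLite):
        def run(self):
            model = nn.Linear(32, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            # setup() returns the wrapped model/optimizer; with strategy="deepspeed" and
            # precision=16, the precision settings forwarded above reach the DeepSpeed config.
            model, optimizer = self.setup(model, optimizer)
            dataloader = self.setup_dataloaders(
                DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
            )
            for (batch,) in dataloader:
                loss = model(batch).sum()
                self.backward(loss)  # routed through the DeepSpeed engine when it is active
                optimizer.step()
                optimizer.zero_grad()


    if __name__ == "__main__":
        BoringLite(strategy="deepspeed", precision=16, gpus=1).run()

Without this forwarding hack, `DeepSpeedPlugin._format_precision_config` would have to read the settings
from `self.lightning_module.trainer`, which the Lite path does not set up in the same way, so the
generated DeepSpeed config would not gain its fp16 section.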