From 655ad28582e9b9f880ac4fc022cf8fecec1b42d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 03:54:39 +0200
Subject: [PATCH 1/8] update setup in deepspeed

---
 .../plugins/training_type/deepspeed.py        | 47 ++++++++++++++-----
 1 file changed, 35 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index e2e8c316f48d1..28f311bdb9c87 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -22,7 +22,9 @@
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -377,6 +379,35 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()
 
+    def setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        if not (len(models) == len(optimizers) == 1):
+            raise ValueError(
+                f"Currently only one model and one optimizer is supported with DeepSpeed."
+                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+            )
+
+        self.config["train_micro_batch_size_per_gpu"] = 1
+        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._set_deepspeed_activation_checkpointing()
+        return [self._model], [optimizer]
+
+    def _setup_model_and_optimizer(
+        self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
+    ):
+        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+            args=argparse.Namespace(device_rank=self.root_device.index),
+            config=self.config,
+            model=model,
+            model_parameters=model_parameters,  # type: ignore
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            dist_init_required=False,
+        )
+        return deepspeed_engine, deepspeed_optimizer
+
     def init_deepspeed(self):
         # check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
         # gradient clipping internally
@@ -441,18 +472,7 @@ def _initialize_deepspeed_train(self, model):
 
         optimizer, lr_scheduler, _ = self._init_optimizers()
         scheduler = lr_scheduler["scheduler"]
-
-        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
-        model, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
-            args=argparse.Namespace(device_rank=self.root_device.index),
-            config=self.config,
-            model=model,
-            model_parameters=model_parameters,
-            optimizer=optimizer,
-            lr_scheduler=scheduler,
-            dist_init_required=False,
-        )
-
+        model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
         self._set_deepspeed_activation_checkpointing()
 
         # although we set these here, deepspeed manages the specific optimizer logic
@@ -568,6 +588,9 @@ def _format_config(self):
         self._format_precision_config()
 
     def _format_batch_size_and_grad_accum_config(self):
+        if self.lightning_module is None:
+            return
+
         if "gradient_accumulation_steps" in self.config:
             raise MisconfigurationException(
                 "Do not set `gradient_accumulation_steps` in the DeepSpeed config"

From 81172ec96999ba7840917adee8b8eef80ff93d8c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 03:59:02 +0200
Subject: [PATCH 2/8] update changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 888d22a520f75..7d465e2088059 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
-
+    * Implemented `setup_models_and_optimizers` for DeepSpeed ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
 
 
 ### Changed

From b53fc7e6820c451d014f4a93becd0e0990e3dec3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 11:44:37 +0200
Subject: [PATCH 3/8] Update pytorch_lightning/plugins/training_type/deepspeed.py

Co-authored-by: Sean Naren
---
 pytorch_lightning/plugins/training_type/deepspeed.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 28f311bdb9c87..9a64a700e3cff 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -588,6 +588,7 @@ def _format_config(self):
         self._format_precision_config()
 
     def _format_batch_size_and_grad_accum_config(self):
+        # todo: using lite, we do not support these variables within the config
         if self.lightning_module is None:
             return
 

From 1b3d51d4730ebbde4892094652d6a52e55507307 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 17:52:20 +0200
Subject: [PATCH 4/8] add docs

---
 pytorch_lightning/plugins/training_type/deepspeed.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index b98185db09e59..efd0812439d79 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -382,6 +382,14 @@ def pre_dispatch(self):
     def _setup_models_and_optimizers(
         self, models: List[Module], optimizers: List[Optimizer]
     ) -> Tuple[List[Module], List[Optimizer]]:
+        """Setup multiple models and multiple optimizers together.
+
+        Currently only one model paired with a single optimizer is supported.
+
+        Return:
+            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            deepspeed optimizer.
+        """
         if not (len(models) == len(optimizers) == 1):
             raise ValueError(
                 f"Currently only one model and one optimizer is supported with DeepSpeed."
@@ -396,6 +404,8 @@ def _setup_models_and_optimizers(
     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
     ):
+        """Initialize one model and one optimizer with an optional learning rate scheduler. This calls
+        :func:`deepspeed.initialize` internally."""
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),

From 52f765a2cfc63b20771598bb1b4e44a9ef3dee99 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 19 Oct 2021 15:53:47 +0000
Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/plugins/training_type/deepspeed.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index efd0812439d79..1b39c2c9907ab 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -404,8 +404,11 @@ def _setup_models_and_optimizers(
     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
     ):
-        """Initialize one model and one optimizer with an optional learning rate scheduler. This calls
-        :func:`deepspeed.initialize` internally."""
+        """Initialize one model and one optimizer with an optional learning rate scheduler.
+
+        This calls
+        :func:`deepspeed.initialize` internally.
+        """
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),

From 7c07b875a8614e872f001fbe9dcf681edba377b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 18:41:19 +0200
Subject: [PATCH 6/8] Update pytorch_lightning/plugins/training_type/deepspeed.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/plugins/training_type/deepspeed.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 1b39c2c9907ab..231f50449db4c 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -406,8 +406,7 @@ def _setup_model_and_optimizer(
     ):
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
-        This calls
-        :func:`deepspeed.initialize` internally.
+        This calls :func:`deepspeed.initialize` internally.
         """
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(

From 05c554badd9bddf26802c858bc7f3ee95f6c5636 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 19:25:26 +0200
Subject: [PATCH 7/8] use set default

---
 pytorch_lightning/plugins/training_type/deepspeed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 231f50449db4c..88743347f2b44 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -396,7 +396,7 @@ def _setup_models_and_optimizers(
                 f" Got {len(models)} models and {len(optimizers)} optimizers instead."
             )
 
-        self.config["train_micro_batch_size_per_gpu"] = 1
+        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
         self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return [self._model], [optimizer]

From 5ca6828244f9eb3cf8e0139691ac175f961b11a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 19:29:01 +0200
Subject: [PATCH 8/8] add comment

---
 pytorch_lightning/plugins/training_type/deepspeed.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 88743347f2b44..019fd41d5d1cc 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -396,6 +396,9 @@ def _setup_models_and_optimizers(
                 f" Got {len(models)} models and {len(optimizers)} optimizers instead."
             )
 
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # normally we set this to the batch size, but it is not available here unless the user provides it
+        # as part of the config
        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
         self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
         self._set_deepspeed_activation_checkpointing()