Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
awaelchli and kaushikb11 authored Apr 10, 2021
1 parent 20ff50c commit fe0d088
Showing 4 changed files with 52 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


## [1.2.7] - 2021-04-06

### Fixed
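For readers skimming the diff, here is a minimal, framework-free sketch of the failure mode described in the changelog entry above: under manual optimization the inherited DDP hook reads `require_backward_grad_sync` on the wrapped model; torch's `DistributedDataParallel` defines that attribute, but fairscale's `ShardedDataParallel` did not, so the lookup raised `AttributeError`. The classes and the hook function below are illustrative stand-ins, not the actual PyTorch Lightning or fairscale code.

# Illustrative stand-ins only -- not the real library classes.
class TorchDDPLike:
    """Mimics torch's DistributedDataParallel, which exposes the sync flag."""
    require_backward_grad_sync = True


class ShardedDDPLike:
    """Mimics fairscale's ShardedDataParallel, which does not define the flag."""


def pre_backward_hook(wrapped_model):
    # Stand-in for the inherited DDP hook that runs before a manual backward pass.
    if wrapped_model.require_backward_grad_sync:
        pass  # would prepare the reducer for backward here


pre_backward_hook(TorchDDPLike())        # fine
try:
    pre_backward_hook(ShardedDDPLike())  # raised AttributeError before this fix
except AttributeError as err:
    print(f"AttributeError: {err}")
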
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -33,6 +36,7 @@ def configure_ddp(self):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -70,3 +74,9 @@ def _optim_state_dict(self, optimizer):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
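A minimal sketch of the mechanism behind the additions above, assuming the inherited DDP hook consults the flag as outlined earlier: the plugin pins `require_backward_grad_sync` to `False` on the wrapped model and overrides `pre_backward` / `post_training_step` as no-ops, since fairscale's `ShardedDataParallel` handles gradient synchronization itself. All class names here are illustrative stand-ins; only the `setattr` call and the two no-op overrides mirror the actual change.

# Illustrative stand-ins only -- not the real plugin classes.
class ShardedDDPLike:
    """Mimics fairscale's ShardedDataParallel (no sync flag of its own)."""


class DDPPluginLike:
    """Mimics the parent DDP plugin's backward hook."""

    def __init__(self, model):
        self.model = model

    def pre_backward(self, closure_loss, should_accumulate, optimizer, opt_idx):
        if self.model.require_backward_grad_sync:
            pass  # the DDP-style path would prepare the reducer here


class ShardedPluginLike(DDPPluginLike):
    """Mirrors the shape of the change in this commit."""

    def __init__(self, model):
        super().__init__(model)
        # Pin the attribute the inherited hook expects; sharded DDP syncs on its own.
        setattr(self.model, "require_backward_grad_sync", False)

    def pre_backward(self, closure_loss, should_accumulate, optimizer, opt_idx):
        pass  # no-op: fairscale's wrapper manages gradient sync internally

    def post_training_step(self):
        pass


plugin = ShardedPluginLike(ShardedDDPLike())
plugin.pre_backward(closure_loss=None, should_accumulate=False, optimizer=None, opt_idx=0)
print("manual-optimization backward hook ran without AttributeError")

In this sketch either change alone would avoid the error; the diff applies both, presumably so that any other inherited code path that reads or resets the flag also finds it defined.
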
10 changes: 10 additions & 0 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,6 +13,9 @@
# limitations under the License.
from typing import Optional

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
@@ -32,6 +35,7 @@ def configure_ddp(self):
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
        )
        setattr(self._model, "require_backward_grad_sync", False)

    def _reinit_optimizers_with_oss(self):
        optimizers = self.lightning_module.trainer.optimizers
@@ -65,3 +69,9 @@ def _optim_state_dict(self, optimizer):
    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_sharded(self._model)

    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
        pass

    def post_training_step(self):
        pass
29 changes: 29 additions & 0 deletions tests/plugins/test_sharded_plugin.py
@@ -278,3 +278,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs):

    trainer.validate(model)
    trainer.test(model)


class ManualBoringModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        output = self(batch)
        loss = self.loss(batch, output)
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}


@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
    model = ManualBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=accelerator,
        fast_dev_run=2,
        gpus=2,
    )
    trainer.fit(model)