Skip to content

Commit 9dd78dd

Browse files
tchatonSeanNarenSeanNarenBorda
authored andcommitted
[bug-fix] DDP and automatic_optimization=False (#4485)
* resolve bug * add self._running_manual_optim * update * update tests * update lightning module * resolve bug * update tests * update * resolve pep8 * update * replace by `ddp_spawn` * temporary fix * update * update * move update to training_loop * make both ddp_spawn * introduce `manual_optimizer_step` * update changelog * added changelog wrong place * add force_optimizer_step * update docstring for tests * update optimizer_step * update zero_grad * resolve flake8 * move update into manual_optimizer_step * add zero_grad * remove zero_grad tests * remove manual_backward in AMP, it doesn't help * update * loosen tests * update * update doc * add TODO * Removed unnecessary get model from native amp * Remove try except with pytest raise * Add seed, clean up imports, remove try catch to reproduce error * update code * update test * revert back * formatting * Update pytorch_lightning/core/lightning.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: SeanNaren <sean@grid.ai> Co-authored-by: Sean Naren <sean.narenthiran@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent b9cfa4e commit 9dd78dd

File tree

9 files changed

+366
-23
lines changed

9 files changed

+366
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ timit_data/
3333
.Python
3434
ide_layouts/
3535
build/
36+
_build/
3637
develop-eggs/
3738
dist/
3839
downloads/

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3333
- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
3434

3535

36+
- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
37+
38+
3639
- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
3740

3841

docs/source/lightning_module.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,12 @@ manual_backward
10091009
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
10101010
:noindex:
10111011

1012+
manual_optimizer_step
1013+
~~~~~~~~~~~~~~~~~~~~~
1014+
1015+
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
1016+
:noindex:
1017+
10121018
on_after_backward
10131019
~~~~~~~~~~~~~~~~~
10141020

docs/source/optimizers.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,16 @@ to manually manage the optimization process. To do so, do the following:
3636
3737
# use self.backward which will also handle scaling the loss when using amp
3838
self.manual_backward(loss_a, opt_g)
39-
opt_g.step()
40-
opt_g.zero_grad()
39+
self.manual_optimizer_step(opt_g)
40+
4141
4242
# do anything you want
4343
loss_b = ...
4444
4545
# pass in any args that loss.backward() normally takes
4646
self.manual_backward(loss_b, opt_d, retain_graph=True)
4747
self.manual_backward(loss_b, opt_d)
48-
opt_d.step()
49-
opt_d.zero_grad()
48+
self.manual_optimizer_step(opt_d)
5049
5150
# log losses
5251
self.log('loss_a', loss_a)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
109109
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
110110
model_ref = self.trainer.get_model()
111111
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
112-
native_amp = self.trainer.amp_backend == AMPType.NATIVE
112+
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
113+
automatic_optimization = self.trainer.train_loop.automatic_optimization
113114

114115
# native amp + lbfgs is a no go right now
115-
if native_amp and is_lbfgs:
116+
if using_native_amp and is_lbfgs:
116117
raise MisconfigurationException(
117118
'native PyTorch amp and lbfgs are not compatible.'
118119
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -125,12 +126,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
125126
optimizer_idx=opt_idx,
126127
optimizer_closure=lambda_closure,
127128
on_tpu=False, # TPUAccelerator class sets this as True
128-
using_native_amp=native_amp,
129+
using_native_amp=using_native_amp,
129130
using_lbfgs=is_lbfgs
130131
)
131132

132133
# scale when native amp
133-
if native_amp:
134+
if automatic_optimization and using_native_amp:
134135
self.trainer.scaler.update()
135136

136137
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):

pytorch_lightning/core/lightning.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs):
111111
self._datamodule = None
112112
self._results: Optional[Result] = None
113113
self._current_fx_name = ''
114+
self._running_manual_backward = False
114115
self._current_hook_fx_name = None
115116
self._current_dataloader_idx = None
116117

@@ -1085,19 +1086,65 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -
10851086
10861087
.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
10871088
1089+
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
1090+
and you use `model.manual_optimizer_step(optimizer)`
1091+
10881092
Example::
10891093
10901094
def training_step(...):
10911095
(opt_a, opt_b) = self.optimizers()
10921096
loss = ...
10931097
# automatically applies scaling, etc...
10941098
self.manual_backward(loss, opt_a)
1099+
self.manual_optimizer_step(opt_a)
10951100
"""
10961101
# make sure we're using manual opt
10971102
self._verify_is_manual_optimization('manual_backward')
10981103

10991104
# backward
1105+
self._running_manual_backward = True
11001106
self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
1107+
self._running_manual_backward = False
1108+
1109+
def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
1110+
"""
1111+
Call this directly from your training_step when doing optimizations manually.
1112+
By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
1113+
1114+
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
1115+
1116+
Args:
1117+
optimizer: Optimizer used to perform `.step()` call
1118+
1119+
force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
1120+
and one should use accumulated gradients but not the other one.
1121+
One could put its own logic to force an optimizer step.
1122+
1123+
Example::
1124+
1125+
def training_step(...):
1126+
(opt_a, opt_b) = self.optimizers()
1127+
loss = ...
1128+
# automatically applies scaling, etc...
1129+
self.manual_backward(loss, opt_a)
1130+
# This will force an opt.step() even if accumulate_grad_batches is set.
1131+
self.manual_optimizer_step(opt_a, force_optimizer_step=True)
1132+
1133+
"""
1134+
# make sure we're using manual opt
1135+
self._verify_is_manual_optimization('manual_optimizer_step')
1136+
1137+
if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
1138+
1139+
# mock closure function as the user is responsible to call `manual_backward`
1140+
def mock_optimizer_closure():
1141+
return
1142+
1143+
self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
1144+
1145+
# update will be called after every optimizer_step call
1146+
if self.trainer.amp_backend == AMPType.NATIVE:
1147+
self.trainer.scaler.update()
11011148

11021149
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
11031150
"""
@@ -1118,7 +1165,8 @@ def backward(self, loss, optimizer, optimizer_idx):
11181165
loss.backward()
11191166
11201167
"""
1121-
loss.backward(*args, **kwargs)
1168+
if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
1169+
loss.backward(*args, **kwargs)
11221170

11231171
def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
11241172
"""

pytorch_lightning/trainer/training_loop.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
306306
# when in dev debugging track the losses
307307
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
308308

309+
def _check_training_step_output(self, training_step_output):
310+
if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
311+
if training_step_output.grad_fn is None:
312+
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
313+
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
314+
309315
def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
310316
# give the PL module a result for logging
311317
model_ref = self.trainer.get_model()
@@ -318,6 +324,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
318324
training_step_output = self.trainer.accelerator_backend.training_step(args)
319325
self.trainer.logger_connector.cache_logged_metrics()
320326

327+
self._check_training_step_output(training_step_output)
328+
321329
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
322330

323331
training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
@@ -690,6 +698,8 @@ def train_step_and_backward_closure():
690698

691699
if self._curr_step_result is None:
692700
# user decided to skip optimization
701+
# make sure to zero grad.
702+
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
693703
continue
694704

695705
batch_outputs = self._process_closure_result(
@@ -701,11 +711,8 @@ def train_step_and_backward_closure():
701711
grad_norm_dic = self._cur_grad_norm_dict
702712
self._cur_grad_norm_dict = None
703713

704-
# hook
705-
self.on_before_zero_grad(optimizer)
706-
707-
# clear gradients
708-
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
714+
# hook + clear gradients
715+
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
709716

710717
# update running loss + reset accumulated loss
711718
self.update_running_loss()
@@ -929,3 +936,14 @@ def update_running_loss(self):
929936

930937
# reset for next set of accumulated grads
931938
self.accumulated_loss.reset()
939+
940+
def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
941+
if self.automatic_optimization:
942+
# hook
943+
self.on_before_zero_grad(optimizer)
944+
optimizers = enumerate([optimizer])
945+
else:
946+
optimizers = self.get_optimizers_iterable()
947+
948+
for idx, optimizer in optimizers:
949+
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

0 commit comments

Comments
 (0)