
[bug-fix] DDP and automatic_optimization=False #4485

Merged: merged 74 commits on Nov 10, 2020

Changes from 69 commits

Commits (74)
d48093b
resolve bug
tchaton Nov 2, 2020
704d100
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 2, 2020
855ec4d
add self._running_manual_optim
tchaton Nov 2, 2020
ec937d5
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 2, 2020
c76ec18
update
tchaton Nov 2, 2020
88f40ea
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 2, 2020
7594c4e
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 2, 2020
32648c3
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 3, 2020
ccacf66
update tests
tchaton Nov 3, 2020
9ecde79
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 3, 2020
df31ec7
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 3, 2020
6378b86
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 3, 2020
d062fd8
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 3, 2020
d98f9a0
update lightning module
tchaton Nov 3, 2020
16359a3
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 3, 2020
63671a5
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 4, 2020
2d6188c
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 4, 2020
bde51ab
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 4, 2020
e858f28
resolve bug
tchaton Nov 4, 2020
35be943
update tests
tchaton Nov 4, 2020
4a86cb8
update
tchaton Nov 4, 2020
4b7d8f2
resolve pep8
tchaton Nov 4, 2020
c465ad0
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 4, 2020
6c6d9d4
update
tchaton Nov 4, 2020
6235ca6
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 4, 2020
b1e2c36
replace by `ddp_spawn`
tchaton Nov 4, 2020
a7673d3
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 4, 2020
9fa1ae6
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 5, 2020
6113110
temporary fix
tchaton Nov 5, 2020
a8c6dea
update
tchaton Nov 5, 2020
fdf7f8b
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 5, 2020
3d30d13
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
ad763ff
update
tchaton Nov 9, 2020
b3891bf
move update to training_loop
tchaton Nov 9, 2020
83d8a2f
make both ddp_spawn
tchaton Nov 9, 2020
26f5578
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
fb46cdb
introduce `manual_optimizer_step`
tchaton Nov 9, 2020
23b40f9
update changelog
tchaton Nov 9, 2020
aa22c85
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 9, 2020
aa03ba1
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
442df35
added changelog wrong place
tchaton Nov 9, 2020
c7d2beb
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 9, 2020
d541016
add force_optimizer_step
tchaton Nov 9, 2020
c96bd73
update docstring for tests
tchaton Nov 9, 2020
37156a7
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
2807c68
update optimizer_step
tchaton Nov 9, 2020
31ea8a4
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 9, 2020
557ece3
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
560c024
update zero_grad
tchaton Nov 9, 2020
2321ca3
resolve flake8
tchaton Nov 9, 2020
a0040c8
move update into manual_optimizer_step
tchaton Nov 9, 2020
1cb0516
add zero_grad
tchaton Nov 9, 2020
41e1de9
remove zero_grad tests
tchaton Nov 9, 2020
60e4b38
remove manual_backward in AMP, it doesn't help
tchaton Nov 9, 2020
80c7113
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 9, 2020
d00dbae
update
tchaton Nov 9, 2020
5864c02
loosen tests
tchaton Nov 9, 2020
2715b12
update
tchaton Nov 9, 2020
5cc0fc4
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 10, 2020
3e520e9
update doc
tchaton Nov 10, 2020
ec5ccb0
Merge branch 'bugfix/4444_ddp_and_automatic_optimization=False' of ht…
tchaton Nov 10, 2020
3a0a374
add TODO
tchaton Nov 10, 2020
5d70185
Removed unnecessary get model from native amp
Nov 10, 2020
8a18292
Remove try except with pytest raise
Nov 10, 2020
43a4bb2
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 10, 2020
811e963
Add seed, clean up imports, remove try catch to reproduce error
Nov 10, 2020
70fe4ac
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
tchaton Nov 10, 2020
0e39d80
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
SeanNaren Nov 10, 2020
aec3dde
Merge branch 'master' into bugfix/4444_ddp_and_automatic_optimization…
SeanNaren Nov 10, 2020
466961a
update code
tchaton Nov 10, 2020
9f88fbc
update test
tchaton Nov 10, 2020
d80ad80
revert back
tchaton Nov 10, 2020
da5cc1e
formatting
Borda Nov 10, 2020
b03b4d7
Update pytorch_lightning/core/lightning.py
SeanNaren Nov 10, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -33,6 +33,7 @@ timit_data/
.Python
ide_layouts/
build/
_build/
Member:

How is this one created? Well, we can keep it, just curious what process created it :]

Contributor Author (tchaton):

I added this one so that, when moving the docs from build to _build and serving them with sphinx-serve, the built files don't get added.

develop-eggs/
dist/
downloads/
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))

- Added `manual_optimizer_step`, which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))

- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))

6 changes: 6 additions & 0 deletions docs/source/lightning_module.rst
@@ -1009,6 +1009,12 @@ manual_backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
:noindex:

manual_optimizer_step
~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~

7 changes: 3 additions & 4 deletions docs/source/optimizers.rst
@@ -36,17 +36,16 @@ to manually manage the optimization process. To do so, do the following:

# use self.backward which will also handle scaling the loss when using amp
self.manual_backward(loss_a, opt_g)
opt_g.step()
opt_g.zero_grad()
self.manual_optimizer_step(opt_g)


# do anything you want
loss_b = ...

# pass in any args that loss.backward() normally takes
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d)
opt_d.step()
opt_d.zero_grad()
self.manual_optimizer_step(opt_d)

# log losses
self.log('loss_a', loss_a)
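For orientation, below is a minimal end-to-end sketch of the manual-optimization flow this diff documents. It is a hedged illustration, not the project's own example: the model, layer sizes, losses, and dataloader are made up, and at the time of this PR manual mode was enabled via the `Trainer(automatic_optimization=False)` flag.

import torch
from torch import nn
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer


class TwoOptimizerModel(LightningModule):
    """Illustrative model driving two optimizers by hand."""

    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(32, 32)
        self.disc = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        # optimizer_idx is supplied per configured optimizer in this version;
        # it is typically ignored when both optimizers are stepped manually
        (opt_g, opt_d) = self.optimizers()

        loss_a = self.gen(batch).pow(2).mean()
        # manual_backward still applies AMP scaling etc. for you
        self.manual_backward(loss_a, opt_g)
        self.manual_optimizer_step(opt_g)

        loss_b = self.disc(batch).abs().mean()
        self.manual_backward(loss_b, opt_d)
        self.manual_optimizer_step(opt_d)

        # log the losses; do not return the loss tensor in manual mode
        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.01),
            torch.optim.SGD(self.disc.parameters(), lr=0.01),
        )

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)


trainer = Trainer(automatic_optimization=False, max_epochs=1)
trainer.fit(TwoOptimizerModel())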
9 changes: 5 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -109,10 +109,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
native_amp = self.trainer.amp_backend == AMPType.NATIVE
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
automatic_optimization = self.trainer.train_loop.automatic_optimization

# native amp + lbfgs is a no go right now
if native_amp and is_lbfgs:
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -125,12 +126,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=native_amp,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs
)

# scale when native amp
if native_amp:
if automatic_optimization and using_native_amp:
self.trainer.scaler.update()

def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
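As background for the `scaler.update()` change above, here is a hedged sketch of the plain PyTorch native-AMP loop that Lightning roughly splits across these call sites: the scaled backward happens in `(manual_)backward`, `scaler.step()` inside the `optimizer_step` path, and `scaler.update()` either here (automatic optimization) or inside `manual_optimizer_step` (manual optimization). The toy model and data are placeholders and a CUDA device is assumed.

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(16, 8, device='cuda')).mean()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration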
53 changes: 52 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None

@@ -1085,19 +1086,68 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -

.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set

.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
and you use `model.manual_optimizer_step(optimizer)`

Example::

def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
self.manual_optimizer_step(opt_a)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_backward')

# backward
self._running_manual_backward = True
self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
self._running_manual_backward = False

def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
Contributor (williamFalcon):

Can we make sure to forward arguments when we do stuff like this?

i.e. right now the user can't pass the args of `.step()`...

So:

def manual_optimizer_step(self, *args, optimizer: Optimizer, force_optimizer_step: bool = False, **kwargs):
    ...  # eventually forwards to:
    optimizer.step(*args, **kwargs)

Contributor Author (tchaton):

Thanks @williamFalcon. I will add a PR for it tomorrow.

"""
Call this directly from your training_step when doing optimizations manually.
By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you

.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.

Args:
optimizer: Optimizer used to perform `.step()` call

force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
and one should use accumulated gradients but not the other one.
One could put its own logic to force an optimizer step.

Return:
None

Example::

def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
# This will force an opt.step() even if accumulate_grad_batches is set.
self.manual_optimizer_step(opt_a, force_optimizer_step=True)

"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_optimizer_step')

if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:

# mock closure function as the user is responsible to call `manual_backward`
def mock_optimizer_closure():
return

self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)

# update will be called after every optimizer_step call
if self.trainer.amp_backend == AMPType.NATIVE:
self.trainer.scaler.update()

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
@@ -1118,7 +1168,8 @@ def backward(self, loss, optimizer, optimizer_idx):
loss.backward()

"""
loss.backward(*args, **kwargs)
if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""
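A hedged, method-only sketch (in the style of the docstring examples above) of the "two optimizers, only one accumulates" case that `force_optimizer_step` is meant for, assuming `Trainer(automatic_optimization=False, accumulate_grad_batches=4)`; the loss helpers are hypothetical.

def training_step(self, batch, batch_idx, optimizer_idx=0):
    (opt_a, opt_b) = self.optimizers()

    loss_a = self.compute_loss_a(batch)  # hypothetical helper
    self.manual_backward(loss_a, opt_a)
    # no-op while should_accumulate() is True; steps roughly every 4th batch
    self.manual_optimizer_step(opt_a)

    loss_b = self.compute_loss_b(batch)  # hypothetical helper
    self.manual_backward(loss_b, opt_b)
    # force_optimizer_step bypasses the accumulation check, so opt_b steps every batch
    self.manual_optimizer_step(opt_b, force_optimizer_step=True)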
28 changes: 23 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -306,6 +306,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

def _check_training_step_output(self, training_step_output):
if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")

def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
# give the PL module a result for logging
model_ref = self.trainer.get_model()
@@ -318,6 +324,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.logger_connector.cache_logged_metrics()

self._check_training_step_output(training_step_output)

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
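From the user's side, the `_check_training_step_output` guard above means a manual-optimization `training_step` should not return a Tensor (the check fires when the returned tensor has no grad_fn, and the PR's TODO ties the underlying failure to DDP's "Expected to mark a variable ready only once" error). A hedged sketch of the compliant shape, with a made-up loss helper and a single optimizer assumed:

def training_step(self, batch, batch_idx):
    opt = self.optimizers()  # single optimizer assumed here
    loss = self.compute_loss(batch)  # hypothetical helper
    self.manual_backward(loss, opt)
    self.manual_optimizer_step(opt)
    # returning `loss` here can now raise MisconfigurationException in manual mode
    self.log('train_loss', loss)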
@@ -690,6 +698,8 @@ def train_step_and_backward_closure():

if self._curr_step_result is None:
# user decided to skip optimization
# make sure to zero grad.
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
continue

batch_outputs = self._process_closure_result(
@@ -701,11 +711,8 @@
grad_norm_dic = self._cur_grad_norm_dict
self._cur_grad_norm_dict = None

# hook
self.on_before_zero_grad(optimizer)

# clear gradients
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
# hook + clear gradients
self.zero_grad_handler(batch_idx, optimizer, opt_idx)

# update running loss + reset accumulated loss
self.update_running_loss()
@@ -929,3 +936,14 @@ def update_running_loss(self):

# reset for next set of accumulated grads
self.accumulated_loss.reset()

def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
if self.automatic_optimization:
# hook
self.on_before_zero_grad(optimizer)
optimizers = enumerate([optimizer])
else:
optimizers = self.get_optimizers_iterable()

for idx, optimizer in optimizers:
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)