[bug-fix] DDP and automatic_optimization=False #4485
Changes from 69 commits
@@ -33,6 +33,7 @@ timit_data/
 .Python
 ide_layouts/
 build/
+_build/
 develop-eggs/
 dist/
 downloads/
@@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs):
         self._datamodule = None
         self._results: Optional[Result] = None
         self._current_fx_name = ''
+        self._running_manual_backward = False
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
@@ -1085,19 +1086,68 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:

        .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set

        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
        and you use `model.manual_optimizer_step(optimizer)`

        Example::

            def training_step(...):
                (opt_a, opt_b) = self.optimizers()
                loss = ...
                # automatically applies scaling, etc...
                self.manual_backward(loss, opt_a)
                self.manual_optimizer_step(opt_a)
        """
        # make sure we're using manual opt
        self._verify_is_manual_optimization('manual_backward')

        # backward
        self._running_manual_backward = True
        self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
        self._running_manual_backward = False

    def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step: bool = False) -> None:
Review comment: Can we make sure to forward arguments when we do stuff like this? I.e. right now the user can't pass arguments through to `.step()`, so:

    def manual_optimizer_step(self, *args, optimizer: Optimizer, force_optimizer_step: bool = False, **kwargs):
        ...  # eventually forwards to:
        optimizer.step(*args, **kwargs)

Reply: Thanks @williamFalcon. I will add a PR for it tomorrow.
        """
        Call this directly from your training_step when doing optimizations manually.
        By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you

        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.

        Args:
            optimizer: Optimizer used to perform `.step()` call

            force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
                and one should use accumulated gradients but not the other one.
                One could put its own logic to force an optimizer step.

        Return:
            None

        Example::

            def training_step(...):
                (opt_a, opt_b) = self.optimizers()
                loss = ...
                # automatically applies scaling, etc...
                self.manual_backward(loss, opt_a)
                # This will force an opt.step() even if accumulate_grad_batches is set.
                self.manual_optimizer_step(opt_a, force_optimizer_step=True)

        """
        # make sure we're using manual opt
        self._verify_is_manual_optimization('manual_optimizer_step')

        if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
            # mock closure function as the user is responsible to call `manual_backward`
            def mock_optimizer_closure():
                return

            self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)

            # update will be called after every optimizer_step call
            if self.trainer.amp_backend == AMPType.NATIVE:
                self.trainer.scaler.update()

    def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
        """

@@ -1118,7 +1168,8 @@ def backward(self, loss, optimizer, optimizer_idx):
                 loss.backward()

         """
-        loss.backward(*args, **kwargs)
+        if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
+            loss.backward(*args, **kwargs)

     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """
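Putting the two hooks together, here is a hedged sketch of a user-side LightningModule doing manual optimization with the API shown in this diff. Only `manual_backward`, `manual_optimizer_step`, and `self.optimizers()` come from the PR; the class name, the gen/disc modules, and the losses are made-up placeholders:

    import torch
    from torch import nn
    import pytorch_lightning as pl


    class ManualOptModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # illustrative two-optimizer setup, e.g. a generator/discriminator pair
            self.gen = nn.Linear(32, 32)
            self.disc = nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            opt_g, opt_d = self.optimizers()

            # placeholder losses; any scalar tensor works
            g_loss = self.disc(self.gen(batch)).mean()
            self.manual_backward(g_loss, opt_g)   # applies AMP scaling, etc.
            self.manual_optimizer_step(opt_g)     # respects accumulate_grad_batches

            d_loss = -self.disc(batch.detach()).mean()
            self.manual_backward(d_loss, opt_d)
            # step this optimizer even while gradients are being accumulated
            self.manual_optimizer_step(opt_d, force_optimizer_step=True)

        def configure_optimizers(self):
            return (
                torch.optim.SGD(self.gen.parameters(), lr=0.01),
                torch.optim.SGD(self.disc.parameters(), lr=0.01),
            )

At the time of this PR, manual optimization was enabled on the Trainer side via the flag named in the title, roughly `Trainer(automatic_optimization=False, accumulate_grad_batches=2, accelerator='ddp')`; the exact Trainer arguments here are an assumption for illustration.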
Review comment: How was this one created? We can keep it, just curious what process produced it :]

Reply: I added this one so that when moving the built docs from `build` to `_build` and serving them with sphinx-serve, the generated files don't get added to git.