
Commit 502adbc

awaelchli, carmocca, ananthsub, and justusschock authored
refactor optimizer loop logic for manual and automatic optimization (#7526)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
1 parent bf46730 commit 502adbc

File tree

8 files changed: +101 / -94 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
     * Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
+    * Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))

 - `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))

pytorch_lightning/callbacks/finetuning.py

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ def _store(

     def on_train_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
-        for opt_idx, optimizer in trainer.train_loop.prepare_optimizers():
+        for opt_idx, optimizer in trainer.train_loop.get_active_optimizers():
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 2 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import types
 from contextlib import contextmanager
 from typing import Callable, Optional
 from weakref import proxy
@@ -207,7 +206,7 @@ def closure_dis():
             profiler_name = "closure_{self._optimizer_idx}"
             closure = do_nothing_closure
         else:
-            if not isinstance(closure, types.FunctionType):
+            if not callable(closure):
                 raise MisconfigurationException("When closure is provided, it should be a function")
             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
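Side note on the `callable` check above: the refactored training loop now builds the optimizer closure with `functools.partial` plus `update_wrapper` (see `make_closure` in `training_loop.py` below), and a partial object is callable but is not an instance of `types.FunctionType`, so the old `isinstance` check would have rejected it. A minimal standalone sketch of the difference (illustrative only, not Lightning code):

import types
from functools import partial, update_wrapper


def training_step_and_backward_closure(split_batch, batch_idx):
    # stand-in body for the real closure
    return None


# build a closure the same way TrainLoop.make_closure does: partial + update_wrapper
closure = update_wrapper(
    partial(training_step_and_backward_closure, split_batch=None, batch_idx=0),
    training_step_and_backward_closure,
)

print(isinstance(closure, types.FunctionType))  # False -> the old check would raise
print(callable(closure))                        # True  -> the new check accepts it
print(closure.__name__)  # "training_step_and_backward_closure", which the updated test asserts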

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -998,7 +998,7 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
             self.optimizer_connector.update_learning_rates(
                 interval='epoch',
                 opt_indices=[
-                    opt_idx for opt_idx, _ in self.train_loop.get_optimizers_iterable(
+                    opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
                         batch_idx=(self.train_loop.total_batch_idx - 1)
                     )  # Select the optimizers which were used in the last batch of the epoch
                 ],

pytorch_lightning/trainer/training_loop.py

Lines changed: 95 additions & 86 deletions
@@ -15,10 +15,12 @@
 from collections import OrderedDict
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Any, Dict, List, Optional, Union
+from functools import partial, update_wrapper
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
+from torch.optim import Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.step_result import Result
@@ -82,9 +84,8 @@ def __init__(
         self.trainer.num_sanity_val_steps = num_sanity_val_steps

     @property
-    def num_optimizers(self):
-        num_optimizers = len(self.get_optimizers_iterable())
-        return num_optimizers
+    def num_active_optimizers(self) -> int:
+        return len(self.get_active_optimizers())

     @property
     def optimizer_freq_cumsum(self):
@@ -234,23 +235,25 @@ def _should_add_batch_output_to_epoch_output(self) -> bool:

         return False

-    def get_optimizers_iterable(self, batch_idx=None):
+    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
         """
-        Generates an iterable with (idx, optimizer) for each optimizer.
+        Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
+        only one of the optimizers is active at a time.
+
+        Returns:
+            A list of tuples (opt_idx, optimizer) of currently active optimizers.
         """
         if not self.trainer.optimizer_frequencies:
             # call training_step once per optimizer
             return list(enumerate(self.trainer.optimizers))

-        if batch_idx is None:
-            batch_idx = self.total_batch_idx
-
+        batch_idx = self.total_batch_idx if batch_idx is None else batch_idx
         optimizers_loop_length = self.optimizer_freq_cumsum[-1]
         current_place_in_loop = batch_idx % optimizers_loop_length

         # find optimzier index by looking for the first {item > current_place} in the cumsum list
-        opt_idx = np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)
-        return [[opt_idx, self.trainer.optimizers[opt_idx]]]
+        opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop))
+        return [(opt_idx, self.trainer.optimizers[opt_idx])]

     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         training_step_output.detach()
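To see what the frequency handling in `get_active_optimizers` does, here is a rough standalone sketch with made-up frequencies (plain Python/NumPy, not the Lightning implementation itself):

import numpy as np

# assume two optimizers with frequencies [2, 1]: opt 0 runs for two batches, then opt 1 for one
optimizer_frequencies = [2, 1]
freq_cumsum = np.cumsum(optimizer_frequencies)  # [2, 3]
optimizers = ["opt_a", "opt_b"]  # placeholders for real torch.optim.Optimizer objects


def get_active_optimizers(batch_idx):
    # position inside the repeating frequency cycle
    current_place = batch_idx % freq_cumsum[-1]
    # first index whose cumulative frequency exceeds the current position
    opt_idx = int(np.argmax(freq_cumsum > current_place))
    return [(opt_idx, optimizers[opt_idx])]


print([get_active_optimizers(i)[0][0] for i in range(6)])  # [0, 0, 1, 0, 0, 1]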
@@ -471,7 +474,7 @@ def run_training_epoch(self):
         train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

         # track epoch output
-        epoch_output = [[] for _ in range(self.num_optimizers)]
+        epoch_output = [[] for _ in range(self.num_active_optimizers)]

         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
@@ -660,7 +663,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # bookkeeping
         self._hiddens = None

-        optimizers = self.prepare_optimizers()
+        optimizers = list(enumerate(self.trainer.optimizers))

         # track all outputs across time and num of optimizers
         batch_outputs = [[] for _ in range(len(optimizers))]
@@ -689,69 +692,88 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         for split_idx, split_batch in enumerate(splits):
             self.split_idx = split_idx

-            # create an iterable for optimizers and loop over them
-            for opt_idx, optimizer in optimizers:
-
-                # toggle model params + set info to logger_connector
-                self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer)
-
-                result = AttributeDict()
-                if self.should_accumulate():
-                    # For gradient accumulation
-
-                    # -------------------
-                    # calculate loss (train step + train step end)
-                    # -------------------
+            if self.trainer.lightning_module.automatic_optimization:
+                for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
+                    result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer)
+                    if result:
+                        batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end)
+                        grad_norm_dict = result.get("grad_norm_dict", {})
+            else:
+                # in manual optimization, there is no looping over optimizers
+                result = self._run_optimization(batch_idx, split_idx, split_batch)
+                if result:
+                    batch_outputs[0].append(result.training_step_output_for_epoch_end)
+
+        output = AttributeDict(
+            signal=0,
+            # todo: Properly aggregate grad_norm accros opt_idx and split_idx
+            grad_norm_dict=grad_norm_dict,
+            training_step_output_for_epoch_end=batch_outputs,
+        )
+        return output

-                    # automatic_optimization=True: perform dpp sync only when performing optimizer_step
-                    # automatic_optimization=False: don't block synchronization here
-                    with self.block_ddp_sync_behaviour():
-                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
+    def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None):
+        # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
+        # opt_idx=0 to opt_idx=None in the signature here

-                # ------------------------------
-                # BACKWARD PASS
-                # ------------------------------
-                # gradient update with accumulated gradients
-                else:
-                    if self.trainer.lightning_module.automatic_optimization:
+        # toggle model params + set info to logger_connector
+        self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer)

-                        def train_step_and_backward_closure():
-                            nonlocal result
-                            result = self.training_step_and_backward(
-                                split_batch, batch_idx, opt_idx, optimizer, self._hiddens
-                            )
-                            return None if result is None else result.loss
+        result = AttributeDict()
+        closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)

-                        # optimizer step
-                        self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
+        if self.should_accumulate():
+            # For gradient accumulation

-                    else:
-                        result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+            # -------------------
+            # calculate loss (train step + train step end)
+            # -------------------

-                if not result:
-                    # user decided to skip optimization
-                    # make sure to zero grad.
-                    continue
+            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
+            # automatic_optimization=False: don't block synchronization here
+            with self.block_ddp_sync_behaviour():
+                closure()

-                # todo: Properly aggregate grad_norm accros opt_idx and split_idx
-                grad_norm_dict = result.get("grad_norm_dict", {})
+        # ------------------------------
+        # BACKWARD PASS
+        # ------------------------------
+        # gradient update with accumulated gradients
+        else:
+            if self.trainer.lightning_module.automatic_optimization:
+                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+            else:
+                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

-                # update running loss + reset accumulated loss
-                self.update_running_loss(result.loss)
+        if not result:
+            # user decided to skip optimization
+            return result

-                batch_outputs = self._process_closure_result(
-                    opt_closure_result=result,
-                    batch_outputs=batch_outputs,
-                    opt_idx=opt_idx,
-                )
+        # update running loss + reset accumulated loss
+        self.update_running_loss(result.loss)

-        result = AttributeDict(
-            signal=0,
-            grad_norm_dict=grad_norm_dict,
-            training_step_output_for_epoch_end=batch_outputs,
-        )
+        self._process_closure_result(result)
         return result

+    def training_step_and_backward_closure(
+        self,
+        split_batch: Any,
+        batch_idx: int,
+        opt_idx: int,
+        optimizer: Optimizer,
+        hiddens,
+        return_result: AttributeDict,
+    ) -> Optional[torch.Tensor]:
+
+        step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
+        if step_result is not None:
+            return_result.update(step_result)
+        return return_result.loss
+
+    def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable:
+        """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
+        partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs)
+        return update_wrapper(partial_func, self.training_step_and_backward_closure)
+
     @contextmanager
     def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
         """
@@ -776,22 +798,16 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
         else:
             yield None

-    def _process_closure_result(
-        self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int
-    ) -> list:
-        if opt_closure_result:
-            # cache metrics
-            self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)
-
-            # check if loss or model weights are nan
-            if self.trainer.terminate_on_nan:
-                self._check_finite(opt_closure_result.loss)
+    def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
+        if not opt_closure_result:
+            return

-            # track all the outputs across all steps
-            batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
-            batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)
+        # cache metrics
+        self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)

-        return batch_outputs
+        # check if loss or model weights are nan
+        if self.trainer.terminate_on_nan:
+            self._check_finite(opt_closure_result.loss)

     def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """Wrap forward, zero_grad and backward in a closure so second order methods work"""
863879
self.trainer.optimizer_connector.update_learning_rates(
864880
interval="step",
865881
monitor_metrics=monitor_metrics,
866-
opt_indices=[opt_idx for opt_idx, _ in self.get_optimizers_iterable()],
882+
opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
867883
)
868884

869885
def increment_accumulated_grad_global_step(self):
@@ -961,13 +977,6 @@ def save_loggers_on_train_batch_end(self):
961977
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
962978
self.trainer.logger.save()
963979

964-
def prepare_optimizers(self):
965-
# in manual optimization we loop over all optimizers at once
966-
optimizers = self.get_optimizers_iterable()
967-
if not self.trainer.lightning_module.automatic_optimization:
968-
optimizers = [optimizers[0]]
969-
return optimizers
970-
971980
def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
972981
# make sure only the gradients of the current optimizer's parameters are calculated
973982
# in the training step to prevent dangling gradients in multiple-optimizer setup.

tests/accelerators/test_accelerator_connector.py

Lines changed: 0 additions & 2 deletions
@@ -295,7 +295,6 @@ def test_accelerator_choice_ddp_kubeflow(device_count_mock, setup_distributed_mo
     class CB(Callback):

         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.use_ddp
             assert isinstance(trainer.accelerator, GPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)
@@ -331,7 +330,6 @@ def test_accelerator_choice_ddp_cpu_kubeflow(device_count_mock, setup_distribute
     class CB(Callback):

         def on_fit_start(self, trainer, pl_module):
-            assert trainer.accelerator_connector.use_ddp
             assert isinstance(trainer.accelerator, CPUAccelerator)
             assert isinstance(trainer.training_type_plugin, DDPPlugin)
             assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment)

tests/core/test_lightning_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def training_epoch_end(self, outputs):
             ...

         def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
-            assert optimizer_closure.__name__ == "train_step_and_backward_closure"
+            assert optimizer_closure.__name__ == "training_step_and_backward_closure"
             # not passing the closure to the optimizer because step is mocked
             # zero_grad is called inside the closure
             if isinstance(optimizer, SGD) and batch_idx % 2 == 0:

tests/deprecated_api/test_remove_1-5.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ def __init__(self):
             self.automatic_optimization = False

         def training_step(self, batch, batch_idx, optimizer_idx):
-            assert optimizer_idx is not None
+            assert optimizer_idx == 0
             return super().training_step(batch, batch_idx)

         def configure_optimizers(self):
