 from collections import OrderedDict
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Any, Dict, List, Optional, Union
+from functools import partial, update_wrapper
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
+from torch.optim import Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.step_result import Result
@@ -82,9 +84,8 @@ def __init__(
         self.trainer.num_sanity_val_steps = num_sanity_val_steps

     @property
-    def num_optimizers(self):
-        num_optimizers = len(self.get_optimizers_iterable())
-        return num_optimizers
+    def num_active_optimizers(self) -> int:
+        return len(self.get_active_optimizers())

     @property
     def optimizer_freq_cumsum(self):
@@ -234,23 +235,25 @@ def _should_add_batch_output_to_epoch_output(self) -> bool:

         return False

-    def get_optimizers_iterable(self, batch_idx=None):
+    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
         """
-        Generates an iterable with (idx, optimizer) for each optimizer.
+        Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
+        only one of the optimizers is active at a time.
+
+        Returns:
+            A list of tuples (opt_idx, optimizer) of currently active optimizers.
         """
         if not self.trainer.optimizer_frequencies:
             # call training_step once per optimizer
             return list(enumerate(self.trainer.optimizers))

-        if batch_idx is None:
-            batch_idx = self.total_batch_idx
-
+        batch_idx = self.total_batch_idx if batch_idx is None else batch_idx
         optimizers_loop_length = self.optimizer_freq_cumsum[-1]
         current_place_in_loop = batch_idx % optimizers_loop_length

         # find optimizer index by looking for the first {item > current_place} in the cumsum list
-        opt_idx = np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)
-        return [[opt_idx, self.trainer.optimizers[opt_idx]]]
+        opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop))
+        return [(opt_idx, self.trainer.optimizers[opt_idx])]

     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         training_step_output.detach()
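
Editor's note: the frequency handling above reduces to a cumulative-sum lookup. A standalone sketch of that selection logic (illustrative names, not part of this commit), assuming two optimizers with frequencies [2, 1]:

import numpy as np

def pick_active_optimizer_index(frequencies, batch_idx):
    # frequencies = [2, 1]: optimizer 0 handles 2 batches, then optimizer 1 handles 1 batch, repeating
    freq_cumsum = np.cumsum(frequencies)          # [2, 3]
    place_in_loop = batch_idx % freq_cumsum[-1]   # position inside one 3-batch cycle
    # the first cumsum entry strictly greater than the current position marks the active optimizer
    return int(np.argmax(freq_cumsum > place_in_loop))

assert [pick_active_optimizer_index([2, 1], i) for i in range(6)] == [0, 0, 1, 0, 0, 1]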
@@ -471,7 +474,7 @@ def run_training_epoch(self):
         train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)

         # track epoch output
-        epoch_output = [[] for _ in range(self.num_optimizers)]
+        epoch_output = [[] for _ in range(self.num_active_optimizers)]

         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
@@ -660,7 +663,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # bookkeeping
         self._hiddens = None

-        optimizers = self.prepare_optimizers()
+        optimizers = list(enumerate(self.trainer.optimizers))

         # track all outputs across time and num of optimizers
         batch_outputs = [[] for _ in range(len(optimizers))]
@@ -689,69 +692,88 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         for split_idx, split_batch in enumerate(splits):
             self.split_idx = split_idx

-            # create an iterable for optimizers and loop over them
-            for opt_idx, optimizer in optimizers:
-
-                # toggle model params + set info to logger_connector
-                self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer)
-
-                result = AttributeDict()
-                if self.should_accumulate():
-                    # For gradient accumulation
-
-                    # -------------------
-                    # calculate loss (train step + train step end)
-                    # -------------------
+            if self.trainer.lightning_module.automatic_optimization:
+                for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
+                    result = self._run_optimization(batch_idx, split_idx, split_batch, opt_idx, optimizer)
+                    if result:
+                        batch_outputs[opt_idx].append(result.training_step_output_for_epoch_end)
+                        grad_norm_dict = result.get("grad_norm_dict", {})
+            else:
+                # in manual optimization, there is no looping over optimizers
+                result = self._run_optimization(batch_idx, split_idx, split_batch)
+                if result:
+                    batch_outputs[0].append(result.training_step_output_for_epoch_end)
+
+        output = AttributeDict(
+            signal=0,
+            # todo: Properly aggregate grad_norm across opt_idx and split_idx
+            grad_norm_dict=grad_norm_dict,
+            training_step_output_for_epoch_end=batch_outputs,
+        )
+        return output

-                    # automatic_optimization=True: perform dpp sync only when performing optimizer_step
-                    # automatic_optimization=False: don't block synchronization here
-                    with self.block_ddp_sync_behaviour():
-                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
+    def _run_optimization(self, batch_idx, split_idx, split_batch, opt_idx=0, optimizer=None):
+        # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
+        #   opt_idx=0 to opt_idx=None in the signature here

-                    # ------------------------------
-                    # BACKWARD PASS
-                    # ------------------------------
-                    # gradient update with accumulated gradients
-                else:
-                    if self.trainer.lightning_module.automatic_optimization:
+        # toggle model params + set info to logger_connector
+        self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer)

-                        def train_step_and_backward_closure():
-                            nonlocal result
-                            result = self.training_step_and_backward(
-                                split_batch, batch_idx, opt_idx, optimizer, self._hiddens
-                            )
-                            return None if result is None else result.loss
+        result = AttributeDict()
+        closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)

-                        # optimizer step
-                        self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
+        if self.should_accumulate():
+            # For gradient accumulation

-                    else:
-                        result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+            # -------------------
+            # calculate loss (train step + train step end)
+            # -------------------

-                if not result:
-                    # user decided to skip optimization
-                    # make sure to zero grad.
-                    continue
+            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
+            # automatic_optimization=False: don't block synchronization here
+            with self.block_ddp_sync_behaviour():
+                closure()

-                # todo: Properly aggregate grad_norm accros opt_idx and split_idx
-                grad_norm_dict = result.get("grad_norm_dict", {})
+        # ------------------------------
+        # BACKWARD PASS
+        # ------------------------------
+        # gradient update with accumulated gradients
+        else:
+            if self.trainer.lightning_module.automatic_optimization:
+                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+            else:
+                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

-                # update running loss + reset accumulated loss
-                self.update_running_loss(result.loss)
+            if not result:
+                # user decided to skip optimization
+                return result

-            batch_outputs = self._process_closure_result(
-                opt_closure_result=result,
-                batch_outputs=batch_outputs,
-                opt_idx=opt_idx,
-            )
+            # update running loss + reset accumulated loss
+            self.update_running_loss(result.loss)

-        result = AttributeDict(
-            signal=0,
-            grad_norm_dict=grad_norm_dict,
-            training_step_output_for_epoch_end=batch_outputs,
-        )
+        self._process_closure_result(result)
         return result

+    def training_step_and_backward_closure(
+        self,
+        split_batch: Any,
+        batch_idx: int,
+        opt_idx: int,
+        optimizer: Optimizer,
+        hiddens,
+        return_result: AttributeDict,
+    ) -> Optional[torch.Tensor]:
+
+        step_result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
+        if step_result is not None:
+            return_result.update(step_result)
+            return return_result.loss
+
+    def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable:
+        """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
+        partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs)
+        return update_wrapper(partial_func, self.training_step_and_backward_closure)
+
     @contextmanager
     def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
         """
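
Editor's note: make_closure works because functools.partial objects accept attribute assignment, so update_wrapper can copy the wrapped function's metadata onto the partial, and the resulting zero-argument callable can be handed to optimizer.step. A minimal, self-contained sketch of the same pattern with a stock torch optimizer (hypothetical names, not part of this commit):

from functools import partial, update_wrapper

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters())  # an optimizer that re-invokes the closure internally
batch = torch.randn(8, 4)

def training_step_and_backward(batch, optimizer):
    # zero grads, compute a toy loss, backpropagate, and return the loss for the optimizer
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()
    loss.backward()
    return loss

closure = partial(training_step_and_backward, batch, optimizer)
update_wrapper(closure, training_step_and_backward)  # preserve __name__/__doc__ for introspection
optimizer.step(closure)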
@@ -776,22 +798,16 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False):
         else:
             yield None

-    def _process_closure_result(
-        self, opt_closure_result: Optional[AttributeDict], batch_outputs: list, opt_idx: int
-    ) -> list:
-        if opt_closure_result:
-            # cache metrics
-            self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)
-
-            # check if loss or model weights are nan
-            if self.trainer.terminate_on_nan:
-                self._check_finite(opt_closure_result.loss)
+    def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
+        if not opt_closure_result:
+            return

-            # track all the outputs across all steps
-            batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
-            batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)
+        # cache metrics
+        self.trainer.logger_connector.cache_training_step_metrics(opt_closure_result)

-        return batch_outputs
+        # check if loss or model weights are nan
+        if self.trainer.terminate_on_nan:
+            self._check_finite(opt_closure_result.loss)

     def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """Wrap forward, zero_grad and backward in a closure so second order methods work"""
@@ -863,7 +879,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         self.trainer.optimizer_connector.update_learning_rates(
             interval="step",
             monitor_metrics=monitor_metrics,
-            opt_indices=[opt_idx for opt_idx, _ in self.get_optimizers_iterable()],
+            opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
         )

     def increment_accumulated_grad_global_step(self):
@@ -961,13 +977,6 @@ def save_loggers_on_train_batch_end(self):
         if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
             self.trainer.logger.save()

-    def prepare_optimizers(self):
-        # in manual optimization we loop over all optimizers at once
-        optimizers = self.get_optimizers_iterable()
-        if not self.trainer.lightning_module.automatic_optimization:
-            optimizers = [optimizers[0]]
-        return optimizers
-
     def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
         # make sure only the gradients of the current optimizer's parameters are calculated
         # in the training step to prevent dangling gradients in multiple-optimizer setup.
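
Editor's note: the active-optimizer machinery above is exercised by a LightningModule that returns optimizer dicts with "frequency" entries from configure_optimizers. A hedged usage sketch (illustrative module, not part of this commit), assuming the optimizer/frequency dictionary format:

import torch
import pytorch_lightning as pl

class FrequencyExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(16, 16)
        self.disc = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # only the optimizer selected by get_active_optimizers is stepped on this batch
        if optimizer_idx == 0:
            return self.gen(batch).pow(2).mean()
        return self.disc(batch).pow(2).mean()

    def configure_optimizers(self):
        # optimizer 0 runs for 2 batches, then optimizer 1 for 1 batch, repeating
        return [
            {"optimizer": torch.optim.Adam(self.gen.parameters(), lr=1e-3), "frequency": 2},
            {"optimizer": torch.optim.Adam(self.disc.parameters(), lr=1e-3), "frequency": 1},
        ]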