@@ -120,7 +120,6 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

-        self.momenta = None
         self.n_averaged = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
@@ -134,6 +133,7 @@ def __init__(
         self._temp_model = None
         self._initialized = False
         self._swa_scheduler = None
+        self._batch_norm_moments = None

     @property
     def swa_start(self) -> int:
@@ -171,12 +171,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
         self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

         self._max_epochs = trainer.max_epochs
-        if self._model_contains_batch_norm:
-            # virtually increase max_epochs to perform batch norm update on latest epoch.
-            trainer.fit_loop.max_epochs += 1

     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        resuming_after_start = trainer.current_epoch > self.swa_start and not self._initialized
+        resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end)
         if trainer.current_epoch == self.swa_start or resuming_after_start:
             self._initialized = True

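The reworked `resuming_after_start` check only re-runs the SWA setup when training resumes from a checkpoint that falls strictly inside the SWA window, instead of any epoch past `swa_start`. A minimal standalone sketch of that predicate; the function and its arguments are hypothetical stand-ins for the callback's `self.swa_start`, `self.swa_end`, `self._initialized`, and `trainer.current_epoch`:

```python
# Hypothetical standalone version of the resume check used above.
def should_reinitialize_swa(current_epoch: int, swa_start: int, swa_end: int, initialized: bool) -> bool:
    # Re-initialize only when resuming after the SWA start epoch (that epoch
    # initializes normally) and no later than the last SWA epoch.
    return (not initialized) and (swa_start < current_epoch <= swa_end)


assert should_reinitialize_swa(current_epoch=12, swa_start=10, swa_end=20, initialized=False)
assert not should_reinitialize_swa(current_epoch=25, swa_start=10, swa_end=20, initialized=False)
```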
@@ -223,75 +220,73 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

-        # Note: No > here in case the callback is saved with the model and training continues
-        if trainer.current_epoch == self.swa_end + 1:
-
-            # Transfer weights from average model to pl_module
-            self.transfer_weights(self._average_model, pl_module)
-
-            # Reset BatchNorm for update
-            self.reset_batch_norm_and_save_state(pl_module)
-
-            # There is no need to perform either backward or optimizer.step as we are
-            # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
-            trainer.num_training_batches += 1
-            trainer.fit_loop._skip_backward = True
-            self._accumulate_grad_batches = trainer.accumulate_grad_batches
-
-            trainer.accumulate_grad_batches = trainer.num_training_batches
-
-    def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
-        trainer.fit_loop._skip_backward = False
-
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
-        if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
-            # BatchNorm epoch update. Reset state
-            trainer.accumulate_grad_batches = self._accumulate_grad_batches
-            trainer.num_training_batches -= 1
-            trainer.fit_loop.max_epochs -= 1
-            self.reset_momenta()
-        elif trainer.current_epoch == self.swa_end:
+        if trainer.current_epoch == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._update_batch_norm_moments(trainer, pl_module, store_moments=False)

     def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
             # Take a temporary copy of the model parameters
             self.transfer_weights(pl_module, self._temp_model)
             # Update the model with the averaged parameters
             self.transfer_weights(self._average_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._update_batch_norm_moments(trainer, pl_module)

     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
             # Copy original model parameters back
             self.transfer_weights(self._temp_model, pl_module)
+            if self._model_contains_batch_norm:
+                self._restore_batch_norm_moments()

     @staticmethod
     def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
         for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
             dst_param.detach().copy_(src_param.to(dst_param.device))

-    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
-        self.momenta = {}
+    def _update_batch_norm_moments(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True
+    ):
+        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166."""
+        prev_momenta = {}
+        self._batch_norm_moments = {}
+
+        was_training = pl_module.training
+        pl_module.train()
+
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            prev_momenta[module] = module.momentum
+            if store_moments:
+                self._batch_norm_moments[module] = (module.running_mean, module.running_var)
             module.running_mean = torch.zeros_like(
                 module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
             )
             module.running_var = torch.ones_like(
                 module.running_var, device=pl_module.device, dtype=module.running_var.dtype
             )
-            self.momenta[module] = module.momentum
             module.momentum = None
             module.num_batches_tracked *= 0

-    def reset_momenta(self):
-        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
-        for bn_module in self.momenta:
-            bn_module.momentum = self.momenta[bn_module]
+        # Recompute mean and variance for all batch norm layers by doing a full pass over the training data
+        for batch, _ in trainer.data_connector.train_data_fetcher:
+            batch = batch.to(pl_module.device)
+            pl_module(batch)
+
+        # Reset model state
+        for bn_module, momenta in prev_momenta.items():
+            bn_module.momentum = momenta
+        pl_module.train(was_training)
+
+    def _restore_batch_norm_moments(self):
+        for bn_module, (mean, variance) in self._batch_norm_moments.items():
+            bn_module.running_mean = mean
+            bn_module.running_var = variance


     @staticmethod
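For reference, the moment-update logic above follows the standard SWA BatchNorm recipe: zero the running statistics, set `momentum=None` so the layers accumulate a cumulative moving average, run one forward pass over the training data, then restore the original momenta. Below is a minimal standalone sketch of that recipe; the `model`, `train_loader`, and `device` arguments and the `(input, target)` batch layout are assumptions made for illustration, not the callback's trainer plumbing:

```python
import torch
from torch import nn


def recompute_bn_stats(model: nn.Module, train_loader, device: torch.device) -> None:
    # Collect all BatchNorm layers; nothing to do if the model has none.
    bn_modules = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_modules:
        return

    prev_momenta = {m: m.momentum for m in bn_modules}
    was_training = model.training
    model.train()

    for m in bn_modules:
        # Reset running statistics; momentum=None switches BatchNorm to a
        # cumulative moving average during the pass below.
        m.running_mean = torch.zeros_like(m.running_mean)
        m.running_var = torch.ones_like(m.running_var)
        m.momentum = None
        m.num_batches_tracked.zero_()

    # One full pass over the training data to accumulate fresh statistics.
    with torch.no_grad():
        for batch, _ in train_loader:
            model(batch.to(device))

    # Restore the original momenta and training mode.
    for m, momentum in prev_momenta.items():
        m.momentum = momentum
    model.train(was_training)
```

`torch.optim.swa_utils.update_bn` implements essentially the same procedure when a plain dataloader is available.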
@@ -317,7 +312,6 @@ def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> dict:
         checkpoint_data = {
-            "momenta": self.momenta,
             "n_averaged": self.n_averaged,
             "swa_lrs": self._swa_lrs,
             "annealing_epochs": self._annealing_epochs,
@@ -330,7 +324,6 @@ def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
         if callback_state:
-            self.momenta = callback_state["momenta"]
             self.n_averaged = callback_state["n_averaged"]
             self._swa_lrs = callback_state["swa_lrs"]
             self._annealing_strategy = callback_state["annealing_strategy"]
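From user code the callback is still enabled the same way. A minimal, self-contained usage sketch; the toy module, data, and hyperparameters below are illustrative assumptions, not part of this change:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Include a BatchNorm layer so the moment-recomputation path is exercised.
        self.net = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(256, 16), torch.randn(256, 1)), batch_size=32)
trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=0.75)],
)
trainer.fit(ToyModule(), train_loader)
```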