2424
2525from pytorch_lightning import LightningModule , Trainer
2626from pytorch_lightning .accelerators import Accelerator
27- from pytorch_lightning .callbacks import StochasticWeightAveraging
27+ from pytorch_lightning .callbacks import StochasticWeightAveraging , ModelCheckpoint
2828from pytorch_lightning .plugins import DDPSpawnPlugin
2929from pytorch_lightning .trainer .connectors .data_connector import _PatchDataLoader
3030from pytorch_lightning .utilities .exceptions import MisconfigurationException
@@ -46,6 +46,7 @@ def __init__(
4646 self .iterable_dataset = iterable_dataset
4747 self .crash_after_epoch = crash_after_epoch
4848 self ._epoch_count = 0
49+ self .save_hyperparameters ()
4950
5051 def training_step (self , batch , batch_idx ):
5152 output = self .forward (batch )
@@ -55,6 +56,7 @@ def training_step(self, batch, batch_idx):
5556 def validation_step (self , batch , batch_idx ):
5657 output = self .forward (batch )
5758 loss = self .loss (batch , output )
59+ self .log ("val_loss" , loss )
5860 return {"x" : loss }
5961
6062 def train_dataloader (self ):
@@ -142,7 +144,7 @@ def on_train_end(self, trainer, pl_module):
142144 assert self .update_parameters_calls == expected_update_calls
143145 if self ._swa_validation :
144146 # 3 weight transfers are needed per SWA validation step
145- assert self .transfer_weights_calls == (self .validation_calls - self ._swa_epoch_start ) * 3 + 1
147+ assert self .transfer_weights_calls == (self .validation_calls - self .swa_start ) * 3 + 1
146148 else :
147149 assert self .transfer_weights_calls == 1
148150
@@ -169,7 +171,8 @@ def train_with_swa(
169171 enable_progress_bar = False ,
170172 max_epochs = max_epochs ,
171173 limit_train_batches = 5 ,
172- limit_val_batches = 1.0 if validation else 0.0 ,
174+ limit_val_batches = 5 if validation else 0 ,
175+ num_sanity_val_steps = 0 ,
173176 callbacks = [swa_callback ],
174177 accumulate_grad_batches = 2 ,
175178 strategy = strategy ,
@@ -362,3 +365,40 @@ def test_swa_resume_training_from_checkpoint(tmpdir):
362365
363366 with mock .patch .object (Accelerator , "backward" , wraps = trainer .accelerator .backward ):
364367 trainer .fit (model )
368+
369+
@pytest.mark.parametrize("batchnorm", (True, False))
@pytest.mark.parametrize("within_swa_epochs", (True, False))
def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool):
    """Check that SWA averaged parameters can be restored from the best checkpoint.

    When SWA is active during the checkpointed epochs (``within_swa_epochs``),
    the best checkpoint must contain averaged weights that
    ``restore_average_parameters_from_checkpoint`` can load; when SWA starts
    after training ends, no checkpoint carries SWA state and restoration must
    report failure.
    """
    model = SwaTestModel(batchnorm=batchnorm)
    max_epochs = 5
    if within_swa_epochs:
        # Start at epoch 1 so every saved checkpoint is guaranteed to hold SWA weights.
        swa_start = 1
    else:
        # Start after the final epoch, so no checkpoint ever holds SWA parameters.
        swa_start = max_epochs + 1

    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=True)
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        max_epochs=max_epochs,
        limit_train_batches=5,
        limit_val_batches=5,
        num_sanity_val_steps=0,
        callbacks=[swa_callback, checkpoint_callback],
        accumulate_grad_batches=2,
        num_processes=1,
    )

    with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward):
        trainer.fit(model)

    best_path = checkpoint_callback.best_model_path
    restored_model = SwaTestModel.load_from_checkpoint(best_path)
    parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint(restored_model, best_path)

    # SWA parameters are present in the best checkpoint exactly when SWA was active.
    assert parameters_loaded == within_swa_epochs
0 commit comments