 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from pathlib import Path
+from typing import ContextManager, Optional
 from unittest import mock
 
 import pytest
 import torch
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.swa_utils import SWALR
 from torch.utils.data import DataLoader
 
 
 class SwaTestModel(BoringModel):
-    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
+    def __init__(
+        self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
+    ):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
@@ -39,17 +45,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
         self.layer = nn.Sequential(*layers)
         self.interval = interval
         self.iterable_dataset = iterable_dataset
+        self.crash_on_epoch = crash_on_epoch
 
     def training_step(self, batch, batch_idx):
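+        # If `crash_on_epoch` is set, raise once that epoch is reached so the resume-from-checkpoint
+        # tests can restart training from the checkpoint saved before the failure.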
+        if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
+            raise Exception("SWA crash test")
         output = self.forward(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}
 
     def train_dataloader(self):
-
         dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
         dset = dset_cls(32, 64)
-
         return DataLoader(dset, batch_size=2)
 
     def configure_optimizers(self):
@@ -66,6 +73,8 @@ def configure_optimizers(self):
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
+    # Record the first epoch, since if we are resuming from a checkpoint it may not be equal to 0
+    first_epoch: Optional[int] = None
 
     def update_parameters(self, *args, **kwargs):
         self.update_parameters_calls += 1
@@ -77,6 +86,11 @@ def transfer_weights(self, *args, **kwargs):
 
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
+        if self.first_epoch is None and not trainer.fit_loop.restarting:
+            # Since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will
+            # not update the model and just calls the epoch-level hooks; for that reason, we check that we
+            # are not restarting before choosing the first epoch.
+            self.first_epoch = trainer.current_epoch
         assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
         if self.swa_start <= trainer.current_epoch:
             assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
@@ -88,6 +102,7 @@ def on_train_epoch_end(self, trainer, *args):
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
             swa_epoch = trainer.current_epoch - self.swa_start
             assert self.n_averaged == swa_epoch + 1
+            assert self._swa_scheduler is not None
             # Scheduler is stepped once on initialization and then at the end of each epoch
             assert self._swa_scheduler._step_count == swa_epoch + 2
         elif trainer.current_epoch > self.swa_end:
@@ -103,10 +118,13 @@ def on_train_end(self, trainer, pl_module):
 
         if not isinstance(trainer.strategy, DDPSpawnStrategy):
             # check backward call count. the batchnorm update epoch should not backward
-            assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
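+            # Epochs completed before the crash (before `first_epoch`) do not contribute backward calls
+            # on the resumed trainer, so only the epochs this trainer actually ran are counted.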
+            assert trainer.strategy.backward.call_count == (
+                (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
+            )
 
         # check call counts
-        assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
+        first_swa_epoch = max(self.first_epoch, self.swa_start)
+        assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
         assert self.transfer_weights_calls == 1
 
 
@@ -140,7 +158,7 @@ def train_with_swa(
         devices=devices,
     )
 
-    with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
+    with _backward_patch(trainer):
         trainer.fit(model)
 
     # check the model is the expected
@@ -226,9 +244,10 @@ def test_swa_multiple_lrs(tmpdir):
 
     class TestModel(BoringModel):
         def __init__(self):
-            super(BoringModel, self).__init__()
+            super().__init__()
             self.layer1 = torch.nn.Linear(32, 32)
             self.layer2 = torch.nn.Linear(32, 2)
+            self.on_train_epoch_start_called = False
 
         def forward(self, x):
             x = self.layer1(x)
@@ -255,3 +274,98 @@ def on_train_epoch_start(self):
     )
     trainer.fit(model)
     assert model.on_train_epoch_start_called
+
+
+def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
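+    # Train `model` until it raises, locate the checkpoint written before the failure, then resume
+    # training `resume_model` from that checkpoint so the SWA callback assertions run across the restart.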
+    swa_start = 3
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 5,
+        "accelerator": "cpu",
+        "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
+        "devices": 2 if ddp else 1,
+        "limit_train_batches": 5,
+        "limit_val_batches": 0,
+        "accumulate_grad_batches": 2,
+        "enable_progress_bar": False,
+    }
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
+        trainer.fit(model)
+
+    checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
+    checkpoint_files = os.listdir(checkpoint_dir)
+    assert len(checkpoint_files) == 1
+    ckpt_path = str(checkpoint_dir / checkpoint_files[0])
+
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer):
+        trainer.fit(resume_model, ckpt_path=ckpt_path)
+
+
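+# Variant of SwaTestModel whose scheduler is a plain `LambdaLR`; used to cover resuming SWA
+# training when a custom scheduler is configured.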
+class CustomSchedulerModel(SwaTestModel):
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+        def lr_lambda(current_step: int):
+            return 0.1
+
+        scheduler = LambdaLR(optimizer, lr_lambda, -1)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": self.interval,
+            },
+        }
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
+    model = SwaTestModel(crash_on_epoch=crash_on_epoch)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
+    # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
+    model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
+    resume_model = CustomSchedulerModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@RunIf(skip_windows=True)
+def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
+    model = SwaTestModel(crash_on_epoch=3)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)
+
+
+@pytest.mark.parametrize(
+    "strategy",
+    [
+        pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
+        pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
+    ],
+)
+def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
+    model = SwaTestModel()
+    swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=5,
+        callbacks=[swa_callback],
+        strategy=strategy,
+        accelerator="gpu",
+        devices=1,
+    )
+    with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
+        trainer.fit(model)
+
+
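+# Wrap `Strategy.backward` with a mock so the SWA callback can assert on the number of backward calls.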
+def _backward_patch(trainer: Trainer) -> ContextManager:
+    return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)