Commit c37abbf

Add test for loading a model from a checkpoint with SWA parameters

1 parent c8db9d8 commit c37abbf

2 files changed: +54 −7 lines changed

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 11 additions & 4 deletions

@@ -255,6 +255,11 @@ def _update_batch_norm_moments(
         prev_momenta = {}
         self._batch_norm_moments = {}

+        train_data_fetcher = trainer.data_connector.train_data_fetcher
+        if train_data_fetcher is None:
+            # Training data not yet connected, could be in a validation sanity check
+            return
+
         was_training = pl_module.training
         pl_module.train()

@@ -274,7 +279,7 @@ def _update_batch_norm_moments(
             module.num_batches_tracked *= 0

         # Recompute mean and variance for all batch norm layers by doing a full pass over the training data
-        for batch, _ in trainer.data_connector.train_data_fetcher:
+        for batch, _ in train_data_fetcher:
             batch = batch.to(pl_module.device)
             pl_module(batch)
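
These two hunks guard the BatchNorm-statistics refresh: the full pass over the training data is skipped when no train data fetcher is connected yet, e.g. during the validation sanity check. For background, the same recipe exists in plain PyTorch; the sketch below is illustrative only (it is not code from this commit) and uses torch.optim.swa_utils.update_bn to reset and recompute the running mean/variance with one pass over a loader.

import torch
from torch.optim.swa_utils import AveragedModel, update_bn

# Toy model containing a BatchNorm layer whose running statistics go stale after
# the SWA weights are swapped in.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))
swa_model = AveragedModel(model)

# One full pass over the training data: update_bn resets num_batches_tracked and
# re-accumulates running mean/var, mirroring the loop in _update_batch_norm_moments.
loader = torch.utils.data.DataLoader(torch.randn(64, 8), batch_size=8)
update_bn(loader, swa_model)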

@@ -316,7 +321,7 @@ def on_save_checkpoint(
             "swa_lrs": self._swa_lrs,
             "annealing_epochs": self._annealing_epochs,
             "annealing_strategy": self._annealing_strategy,
-            "average_model_parameters": self._get_average_model_parameters(),
+            "average_model_parameters": self._get_average_model_parameters(trainer),
         }
         return checkpoint_data

@@ -380,8 +385,10 @@ def restore_average_parameters_from_checkpoint(
             p_model.detach().copy_(p_swa_)
         return True

-    def _get_average_model_parameters(self) -> Any:
-        if self._average_model is None:
+    def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any:
+        if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end):
+            # If we're not within the SWA epochs then when loading checkpoint data we would want
+            # to use parameters from the underlying model rather than the SWA parameters.
             return None
         parameters = []
         for p_swa in self._average_model.parameters():
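
With the extra trainer argument, _get_average_model_parameters only returns the averaged weights while the current epoch lies inside the SWA window; checkpoints saved before swa_start (or after swa_end) store None and therefore keep the plain model weights. A standalone sketch of that gating rule, with illustrative names and the assumption that swa_end resolves to the final training epoch:

def swa_parameters_in_checkpoint(current_epoch: int, swa_start: int, swa_end: int) -> bool:
    # Outside [swa_start, swa_end] the checkpoint's "average_model_parameters" entry is None,
    # so restoring from it falls back to the underlying (non-averaged) parameters.
    return swa_start <= current_epoch <= swa_end


# With max_epochs=5 and swa_start=1 (the test's "within SWA epochs" case), epoch 3 qualifies;
# with swa_start=6 the window is never entered, so no checkpoint carries SWA weights.
assert swa_parameters_in_checkpoint(current_epoch=3, swa_start=1, swa_end=4)
assert not swa_parameters_in_checkpoint(current_epoch=3, swa_start=6, swa_end=4)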

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 43 additions & 3 deletions

@@ -24,7 +24,7 @@

 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.callbacks import StochasticWeightAveraging
+from pytorch_lightning.callbacks import StochasticWeightAveraging, ModelCheckpoint
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -46,6 +46,7 @@ def __init__(
         self.iterable_dataset = iterable_dataset
         self.crash_after_epoch = crash_after_epoch
         self._epoch_count = 0
+        self.save_hyperparameters()

     def training_step(self, batch, batch_idx):
         output = self.forward(batch)

@@ -55,6 +56,7 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         output = self.forward(batch)
         loss = self.loss(batch, output)
+        self.log("val_loss", loss)
         return {"x": loss}

     def train_dataloader(self):

@@ -142,7 +144,7 @@ def on_train_end(self, trainer, pl_module):
         assert self.update_parameters_calls == expected_update_calls
         if self._swa_validation:
             # 3 weight transfers are needed per SWA validation step
-            assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1
+            assert self.transfer_weights_calls == (self.validation_calls - self.swa_start) * 3 + 1
         else:
             assert self.transfer_weights_calls == 1

@@ -169,7 +171,8 @@ def train_with_swa(
         enable_progress_bar=False,
         max_epochs=max_epochs,
         limit_train_batches=5,
-        limit_val_batches=1.0 if validation else 0.0,
+        limit_val_batches=5 if validation else 0,
+        num_sanity_val_steps=0,
         callbacks=[swa_callback],
         accumulate_grad_batches=2,
         strategy=strategy,

@@ -362,3 +365,40 @@ def test_swa_resume_training_from_checkpoint(tmpdir):

     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("batchnorm", (True, False))
+@pytest.mark.parametrize("within_swa_epochs", (True, False))
+def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool):
+    model = SwaTestModel(batchnorm=batchnorm)
+    if within_swa_epochs:
+        # Start at epoch 1 so we can guarantee the best checkpoint should be saved with SWA weights
+        swa_start = 1
+    else:
+        # Start after the last epoch, so we never save a checkpoint with SWA parameters
+        swa_start = 6
+    max_epochs = 5
+
+    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=True)
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=3, mode='min')
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=max_epochs,
+        limit_train_batches=5,
+        limit_val_batches=5,
+        num_sanity_val_steps=0,
+        callbacks=[swa_callback, checkpoint_callback],
+        accumulate_grad_batches=2,
+        num_processes=1,
+    )
+
+    with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward):
+        trainer.fit(model)
+
+    checkpoint_path = checkpoint_callback.best_model_path
+    new_model = SwaTestModel.load_from_checkpoint(checkpoint_path)
+    parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint(new_model, checkpoint_path)
+
+    assert parameters_loaded == within_swa_epochs
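
The tail of the test doubles as the intended usage pattern: load the best checkpoint as usual, then ask the SWA callback to swap in the averaged parameters if the checkpoint contains them. A minimal sketch of that pattern, assuming this branch's StochasticWeightAveraging exposes the same class-level restore_average_parameters_from_checkpoint call the test uses; PlainModule and "best.ckpt" are placeholders, not part of the commit.

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import StochasticWeightAveraging


class PlainModule(LightningModule):
    # Minimal stand-in for SwaTestModel; save_hyperparameters() lets
    # load_from_checkpoint rebuild the module from the checkpoint alone.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.save_hyperparameters()


model = PlainModule.load_from_checkpoint("best.ckpt")  # plain (non-averaged) weights
# True only when the checkpoint was written inside the SWA epoch window, i.e.
# "average_model_parameters" was populated; otherwise the plain weights are kept.
swa_restored = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(model, "best.ckpt")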
