
Commit 41d0687

adamreeve, tchaton, pre-commit-ci[bot], awaelchli, carmocca authored and committed
Support checkpoint save and load with Stochastic Weight Averaging (Lightning-AI#9938)
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
1 parent 76f6aa9 commit 41d0687
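
In user-facing terms, this change lets a run that enables the `StochasticWeightAveraging` callback resume cleanly from a checkpoint, instead of reinitializing the SWA state at the resume epoch. Below is a minimal sketch of that workflow; the `LitModel` module, the checkpoint path, and the SWA hyperparameters are illustrative assumptions, not taken from the commit itself.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import StochasticWeightAveraging


    class LitModel(pl.LightningModule):
        """Minimal stand-in module; any LightningModule is handled the same way."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch[0]).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=3)
    trainer = pl.Trainer(max_epochs=10, callbacks=[swa])

    # Resume an interrupted run: with this commit the SWA callback's averaged model,
    # `n_averaged` counter and SWALR scheduler state are restored from the checkpoint
    # rather than being reset. "path/to/last.ckpt" is a placeholder for a checkpoint
    # saved by an earlier run.
    trainer.fit(LitModel(), train_loader, ckpt_path="path/to/last.ckpt")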

File tree

3 files changed: +195 / -14 lines

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))


+- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
+
+
 - Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))


src/pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 71 additions & 7 deletions
@@ -16,14 +16,15 @@
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 """
 from copy import deepcopy
-from typing import Any, Callable, cast, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Union

 import torch
 from torch import nn, Tensor
 from torch.optim.swa_utils import SWALR

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig

@@ -112,15 +113,22 @@ def __init__(
         if device is not None and not isinstance(device, (torch.device, str)):
             raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

+        self.n_averaged: Optional[torch.Tensor] = None
         self._swa_epoch_start = swa_epoch_start
         self._swa_lrs = swa_lrs
         self._annealing_epochs = annealing_epochs
         self._annealing_strategy = annealing_strategy
         self._avg_fn = avg_fn or self.avg_fn
         self._device = device
-        self._max_epochs: int
-        self._model_contains_batch_norm: bool
+        self._model_contains_batch_norm: Optional[bool] = None
         self._average_model: "pl.LightningModule"
+        self._initialized = False
+        self._swa_scheduler: Optional[_LRScheduler] = None
+        self._scheduler_state: Optional[Dict] = None
+        self._init_n_averaged = 0
+        self._latest_update_epoch = -1
+        self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None
+        self._max_epochs: int

     @property
     def swa_start(self) -> int:

@@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if len(trainer.lr_scheduler_configs) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

+        if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
+            raise MisconfigurationException("SWA does not currently support sharded models.")
+
         if isinstance(self._swa_epoch_start, float):
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

@@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         assert trainer.fit_loop.max_epochs is not None
         trainer.fit_loop.max_epochs += 1

+        if self._scheduler_state is not None:
+            self._clear_schedulers(trainer)
+
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if trainer.current_epoch == self.swa_start:
+        if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
+            self._initialized = True
+
             # move average model to request device.
             self._average_model = self._average_model.to(self._device or pl_module.device)

@@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
                     last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
                 ),
             )
+            if self._scheduler_state is not None:
+                # Restore scheduler state from checkpoint
+                self._swa_scheduler.load_state_dict(self._scheduler_state)
+            elif trainer.current_epoch != self.swa_start:
+                # Log a warning if we're initializing after start without any checkpoint data,
+                # as behaviour will be different compared to having checkpoint data.
+                rank_zero_warn(
+                    "SWA is initializing after swa_start without any checkpoint data. "
+                    "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
+                )
+
             # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
             default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
             assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

@@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             else:
                 trainer.lr_scheduler_configs.append(default_scheduler_cfg)

-            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
+            if self.n_averaged is None:
+                self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

-        if self.swa_start <= trainer.current_epoch <= self.swa_end:
+        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
+            trainer.current_epoch > self._latest_update_epoch
+        ):
+            assert self.n_averaged is not None
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
+            self._latest_update_epoch = trainer.current_epoch

         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
-
             # Transfer weights from average model to pl_module
             self.transfer_weights(self._average_model, pl_module)

@@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No

     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
+        assert self.momenta is not None
         for bn_module in self.momenta:
             bn_module.momentum = self.momenta[bn_module]

@@ -285,3 +317,35 @@ def update_parameters(
     def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
         return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
+            "latest_update_epoch": self._latest_update_epoch,
+            "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
+            "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        self._init_n_averaged = state_dict["n_averaged"]
+        self._latest_update_epoch = state_dict["latest_update_epoch"]
+        self._scheduler_state = state_dict["scheduler_state"]
+        self._load_average_model_state(state_dict["average_model_state"])
+
+    @staticmethod
+    def _clear_schedulers(trainer: "pl.Trainer") -> None:
+        # If we have scheduler state saved, clear the scheduler configs so that we don't try to
+        # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
+        # We'll configure the scheduler and re-load its state in on_train_epoch_start.
+        # Note that this relies on the callback state being restored before the scheduler state is
+        # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
+        # writing that is only True for deepspeed which is already not supported by SWA.
+        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
+        if trainer.lr_scheduler_configs:
+            assert len(trainer.lr_scheduler_configs) == 1
+            trainer.lr_scheduler_configs.clear()
+
+    def _load_average_model_state(self, model_state: Any) -> None:
+        if self._average_model is None:
+            return
+        self._average_model.load_state_dict(model_state)
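
With the `state_dict`/`load_state_dict` hooks above, the SWA state now travels inside the regular Lightning checkpoint. As a hedged sketch of what that looks like on disk (the checkpoint path is a placeholder and the lookup via the callback's default `state_key` is an assumption, not something this diff shows), the saved fields can be inspected offline like this:

    import torch

    # Placeholder path to a checkpoint written while the SWA callback was active.
    ckpt = torch.load("path/to/last.ckpt", map_location="cpu")

    # Lightning stores per-callback state under the "callbacks" entry, keyed by each
    # callback's state_key; match on the class name to find the SWA entry.
    swa_state = next(state for key, state in ckpt["callbacks"].items() if "StochasticWeightAveraging" in key)

    swa_state["n_averaged"]           # number of models folded into the running average so far
    swa_state["latest_update_epoch"]  # last epoch at which the average was updated
    swa_state["scheduler_state"]      # SWALR state dict, or None if SWA had not started yet
    swa_state["average_model_state"]  # state_dict of the averaged model, or None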

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 121 additions & 7 deletions
@@ -12,11 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from pathlib import Path
+from typing import ContextManager, Optional
 from unittest import mock

 import pytest
 import torch
 from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.swa_utils import SWALR
 from torch.utils.data import DataLoader

@@ -30,7 +34,9 @@


 class SwaTestModel(BoringModel):
-    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
+    def __init__(
+        self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
+    ):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:

@@ -39,17 +45,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
         self.layer = nn.Sequential(*layers)
         self.interval = interval
         self.iterable_dataset = iterable_dataset
+        self.crash_on_epoch = crash_on_epoch

     def training_step(self, batch, batch_idx):
+        if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
+            raise Exception("SWA crash test")
         output = self.forward(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}

     def train_dataloader(self):
-
         dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
         dset = dset_cls(32, 64)
-
         return DataLoader(dset, batch_size=2)

     def configure_optimizers(self):

@@ -66,6 +73,8 @@ def configure_optimizers(self):
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
+    # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0
+    first_epoch: Optional[int] = None

     def update_parameters(self, *args, **kwargs):
         self.update_parameters_calls += 1

@@ -77,6 +86,11 @@ def transfer_weights(self, *args, **kwargs):

     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
+        if self.first_epoch is None and not trainer.fit_loop.restarting:
+            # since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will
+            # not update the model and just call the epoch-level hooks, for that reason, we check that we are not
+            # restarting before choosing the first epoch
+            self.first_epoch = trainer.current_epoch
         assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
         if self.swa_start <= trainer.current_epoch:
             assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)

@@ -88,6 +102,7 @@ def on_train_epoch_end(self, trainer, *args):
         if self.swa_start <= trainer.current_epoch <= self.swa_end:
             swa_epoch = trainer.current_epoch - self.swa_start
             assert self.n_averaged == swa_epoch + 1
+            assert self._swa_scheduler is not None
             # Scheduler is stepped once on initialization and then at the end of each epoch
             assert self._swa_scheduler._step_count == swa_epoch + 2
         elif trainer.current_epoch > self.swa_end:

@@ -103,10 +118,13 @@ def on_train_end(self, trainer, pl_module):

         if not isinstance(trainer.strategy, DDPSpawnStrategy):
             # check backward call count. the batchnorm update epoch should not backward
-            assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
+            assert trainer.strategy.backward.call_count == (
+                (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
+            )

         # check call counts
-        assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
+        first_swa_epoch = max(self.first_epoch, self.swa_start)
+        assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
         assert self.transfer_weights_calls == 1

@@ -140,7 +158,7 @@ def train_with_swa(
         devices=devices,
     )

-    with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
+    with _backward_patch(trainer):
         trainer.fit(model)

     # check the model is the expected

@@ -226,9 +244,10 @@ def test_swa_multiple_lrs(tmpdir):

     class TestModel(BoringModel):
         def __init__(self):
-            super(BoringModel, self).__init__()
+            super().__init__()
             self.layer1 = torch.nn.Linear(32, 32)
             self.layer2 = torch.nn.Linear(32, 2)
+            self.on_train_epoch_start_called = False

         def forward(self, x):
             x = self.layer1(x)

@@ -255,3 +274,98 @@ def on_train_epoch_start(self):
     )
     trainer.fit(model)
     assert model.on_train_epoch_start_called
+
+
+def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
+    swa_start = 3
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "max_epochs": 5,
+        "accelerator": "cpu",
+        "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
+        "devices": 2 if ddp else 1,
+        "limit_train_batches": 5,
+        "limit_val_batches": 0,
+        "accumulate_grad_batches": 2,
+        "enable_progress_bar": False,
+    }
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
+        trainer.fit(model)
+
+    checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
+    checkpoint_files = os.listdir(checkpoint_dir)
+    assert len(checkpoint_files) == 1
+    ckpt_path = str(checkpoint_dir / checkpoint_files[0])
+
+    trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
+
+    with _backward_patch(trainer):
+        trainer.fit(resume_model, ckpt_path=ckpt_path)
+
+
+class CustomSchedulerModel(SwaTestModel):
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+        def lr_lambda(current_step: int):
+            return 0.1
+
+        scheduler = LambdaLR(optimizer, lr_lambda, -1)
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": self.interval,
+            },
+        }
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
+    model = SwaTestModel(crash_on_epoch=crash_on_epoch)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
+    # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
+    model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
+    resume_model = CustomSchedulerModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model)
+
+
+@RunIf(skip_windows=True)
+def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
+    model = SwaTestModel(crash_on_epoch=3)
+    resume_model = SwaTestModel()
+    _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)
+
+
+@pytest.mark.parametrize(
+    "strategy",
+    [
+        pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
+        pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
+    ],
+)
+def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
+    model = SwaTestModel()
+    swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=5,
+        callbacks=[swa_callback],
+        strategy=strategy,
+        accelerator="gpu",
+        devices=1,
+    )
+    with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
+        trainer.fit(model)
+
+
+def _backward_patch(trainer: Trainer) -> ContextManager:
+    return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
