
Commit 1696273

Allow restoring SWA parameters to a model from a checkpoint
1 parent 3d2bf65 commit 1696273

File tree

1 file changed: +48 −1 lines changed


pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 48 additions & 1 deletion
@@ -16,7 +16,7 @@
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 """
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, IO, List, Optional, Type, Union
 
 import torch
 from torch import nn
@@ -26,6 +26,7 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
@@ -340,6 +341,52 @@ def on_load_checkpoint(
                 f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state."
             )
 
+    @classmethod
+    def restore_average_parameters_from_checkpoint(
+        cls,
+        pl_module: "pl.LightningModule",
+        checkpoint_path: Union[str, IO],
+        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+    ) -> bool:
+        r"""
+        Set model weights to the SWA averaged weights saved in a checkpoint.
+
+        Arguments:
+            pl_module: The module to set weights on
+            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
+            map_location:
+                If your checkpoint saved a GPU model and you now load on CPUs
+                or a different number of GPUs, use this to map to the new setup.
+                The behaviour is the same as in :func:`torch.load`.
+
+        Return:
+            A `bool` indicating whether averaged weights were loaded. If `False`, this means the checkpoint is
+            from an epoch before the SWA epoch start.
+        """
+        if map_location is not None:
+            checkpoint = pl_load(checkpoint_path, map_location=map_location)
+        else:
+            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+        if not callback_states:
+            raise ValueError("callback states are not present in the checkpoint")
+
+        state_key = cls.__qualname__  # Default state key defined in Callback base class
+        state = callback_states.get(state_key)
+        if not state:
+            raise ValueError(f"no {state_key} state found in the checkpoint")
+        state = deepcopy(state)
+        average_model_parameters = state["average_model_parameters"]
+
+        if not average_model_parameters:
+            return False
+
+        for p_model, p_swa in zip(pl_module.parameters(), average_model_parameters):
+            device = p_model.device
+            p_swa_ = p_swa.detach().to(device)
+            p_model.detach().copy_(p_swa_)
+        return True
+
     def _get_average_model_parameters(self) -> Any:
         if self._average_model is None:
             return None
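
For context, a minimal usage sketch of the new classmethod. This is an illustration under assumptions, not part of the commit: MyLightningModule and the checkpoint path are hypothetical placeholders.

from pytorch_lightning.callbacks import StochasticWeightAveraging

# Assumes the checkpoint was written by a Trainer run that used the SWA
# callback, so its "callbacks" state carries the averaged parameters.
model = MyLightningModule()  # hypothetical LightningModule subclass
loaded = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
    model,
    "checkpoints/epoch=19.ckpt",  # placeholder path
    map_location="cpu",  # optional; same semantics as in torch.load
)
if not loaded:
    print("Checkpoint predates the SWA start epoch; weights left unchanged.")

If the checkpoint was saved before the SWA start epoch, the stored average_model_parameters are empty and the method returns False, leaving the module's weights untouched.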
