16 | 16 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17 | 17 | """ |
18 | 18 | from copy import deepcopy |
19 | | -from typing import Any, Callable, Dict, List, Optional, Union |
| 19 | +from typing import Any, Callable, Dict, IO, List, Optional, Type, Union |
20 | 20 |
21 | 21 | import torch |
22 | 22 | from torch import nn |
26 | 26 | from pytorch_lightning.callbacks.base import Callback |
27 | 27 | from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config |
28 | 28 | from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn |
| 29 | +from pytorch_lightning.utilities.cloud_io import load as pl_load |
29 | 30 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
30 | 31 |
31 | 32 | _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] |
@@ -340,6 +341,52 @@ def on_load_checkpoint( |
340 | 341 | f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." |
341 | 342 | ) |
342 | 343 |
| 344 | + @classmethod |
| 345 | + def restore_average_parameters_from_checkpoint( |
| 346 | + cls, |
| 347 | + pl_module: "pl.LightningModule", |
| 348 | + checkpoint_path: Union[str, IO], |
| 349 | + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, |
| 350 | + ) -> bool: |
| 351 | + r""" |
| 352 | + Set model weights to the SWA averaged weights saved in a checkpoint. |
| 353 | +
| 354 | + Arguments: |
| 355 | + pl_module: The module to set weights on |
| 356 | + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object |
| 357 | + map_location: |
| 358 | + If your checkpoint saved a GPU model and you now load on CPUs |
| 359 | + or a different number of GPUs, use this to map to the new setup. |
| 360 | + The behaviour is the same as in :func:`torch.load`. |
| 361 | +
| 362 | + Return: |
| 364 | + A `bool` indicating whether averaged weights were loaded. If `False`, the checkpoint was saved
| 365 | + at an epoch before SWA started, so no averaged weights were available.
| 365 | + """ |
| 366 | + if map_location is not None: |
| 367 | + checkpoint = pl_load(checkpoint_path, map_location=map_location) |
| 368 | + else: |
| 369 | + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) |
| 370 | + callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") |
| 371 | + if not callback_states: |
| 372 | + raise ValueError("callback states are not present in the checkpoint") |
| 373 | + |
| 374 | + state_key = cls.__qualname__ # Default state key defined in Callback base class |
| 375 | + state = callback_states.get(state_key) |
| 376 | + if not state: |
| 377 | + raise ValueError(f"no {state_key} state found in the checkpoint") |
| 378 | + state = deepcopy(state) |
| 379 | + average_model_parameters = state["average_model_parameters"] |
| 380 | + |
| 381 | + if not average_model_parameters: |
| 382 | + return False |
| 383 | + |
| 384 | + for p_model, p_swa in zip(pl_module.parameters(), average_model_parameters): |
| 385 | + device = p_model.device |
| 386 | + p_swa_ = p_swa.detach().to(device) |
| 387 | + p_model.detach().copy_(p_swa_) |
| 388 | + return True |
| 389 | + |
343 | 390 | def _get_average_model_parameters(self) -> Any: |
344 | 391 | if self._average_model is None: |
345 | 392 | return None |
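For context, a usage sketch of the new classmethod follows. It assumes the method lands on the `StochasticWeightAveraging` callback (the enclosing class name is not visible in this hunk), and `MyLitModel` together with the checkpoint path are placeholders for illustration only; the call itself mirrors the signature added above.

import torch

from pytorch_lightning.callbacks import StochasticWeightAveraging

# Hypothetical LightningModule subclass and checkpoint path, for illustration only.
model = MyLitModel.load_from_checkpoint("path/to/checkpoint.ckpt")

# Replace the module's weights with the SWA running average stored in the same
# checkpoint. Returns False when the checkpoint predates the SWA start epoch,
# in which case the original weights are left untouched.
restored = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
    model,
    "path/to/checkpoint.ckpt",
    map_location=torch.device("cpu"),  # optional; same semantics as torch.load
)
if not restored:
    print("Checkpoint was saved before SWA started; averaged weights unavailable.")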