Skip to content

Commit

Permalink
add legacy load utility (#9166)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 23, 2021
1 parent 491e4a2 commit 87b11fb
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 47 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))


- Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
Expand Down Expand Up @@ -242,9 +245,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)


- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))


- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260))


Expand Down Expand Up @@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `profiled_functions` argument from `PyTorchProfiler` ([#9178](https://github.com/PyTorchLightning/pytorch-lightning/pull/9178))


- Removed deprecated `pytorch_lightning.utilities.argparse_utils` module ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))


- Removed deprecated property `Trainer.running_sanity_check` in favor of `Trainer.sanity_checking` ([#9209](https://github.com/PyTorchLightning/pytorch-lightning/pull/9209))


Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -125,10 +126,11 @@ def load_from_checkpoint(
pretrained_model.freeze()
y_hat = pretrained_model(x)
"""
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
with pl_legacy_patch():
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

if hparams_file is not None:
extension = hparams_file.split(".")[-1]
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

Expand Down Expand Up @@ -65,7 +66,8 @@ def resume_start(self) -> None:
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
with pl_legacy_patch():
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
Expand Down
7 changes: 0 additions & 7 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,6 @@ def _gpus_allowed_type(x: str) -> Union[int, str]:
return int(x)


def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover
# unused, but here for backward compatibility with old checkpoints that need to be able to
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
pass


def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
if "." in str(x):
return float(x)
Expand Down
7 changes: 0 additions & 7 deletions pytorch_lightning/utilities/argparse_utils.py

This file was deleted.

48 changes: 48 additions & 0 deletions pytorch_lightning/utilities/migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from types import ModuleType

import pytorch_lightning.utilities.argparse


class pl_legacy_patch:
    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
    unpickling old checkpoints. The following patches apply.

    1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
    2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
       but still needs to be available for import for legacy checkpoints.

    Example:

        with pl_legacy_patch():
            torch.load("path/to/legacy/checkpoint.ckpt")
    """

    def __enter__(self):
        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`; register a stub module
        # under the old name so pickled references to it resolve during `torch.load`.
        legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
        sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module

        # `_gpus_arg_default` used to be imported from these locations
        legacy_argparse_module._gpus_arg_default = lambda x: x
        pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Undo the patches applied in `__enter__`.
        if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
            delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
        # `pop(..., None)` instead of `del`: tolerate the entry having been removed (or replaced)
        # by code running inside the context, which would otherwise raise `KeyError` here.
        sys.modules.pop("pytorch_lightning.utilities.argparse_utils", None)
4 changes: 3 additions & 1 deletion pytorch_lightning/utilities/upgrade_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.migration import pl_legacy_patch

KEYS_MAPPING = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
Expand Down Expand Up @@ -58,4 +59,5 @@ def upgrade_checkpoint(filepath):

log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
copyfile(args.file, args.file + ".bak")
upgrade_checkpoint(args.file)
with pl_legacy_patch():
upgrade_checkpoint(args.file)
24 changes: 0 additions & 24 deletions tests/deprecated_api/test_remove_2-0.py

This file was deleted.

36 changes: 36 additions & 0 deletions tests/utilities/test_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pytorch_lightning
from pytorch_lightning.utilities.migration import pl_legacy_patch


def test_patch_legacy_argparse_utils():
    """The legacy ``argparse_utils`` module is importable inside the patch and unregistered afterwards."""
    legacy_name = "pytorch_lightning.utilities.argparse_utils"
    with pl_legacy_patch():
        from pytorch_lightning.utilities import argparse_utils

        assert callable(argparse_utils._gpus_arg_default)
        assert legacy_name in sys.modules

    assert legacy_name not in sys.modules


def test_patch_legacy_gpus_arg_default():
    """``_gpus_arg_default`` is importable inside the patch and removed again on exit."""
    with pl_legacy_patch():
        from pytorch_lightning.utilities.argparse import _gpus_arg_default

        assert callable(_gpus_arg_default)
    # the patched attribute must be cleaned up once the context manager exits
    assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")

0 comments on commit 87b11fb

Please sign in to comment.