From 8f76c2600883ac4947e24569945bdf23c2774370 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 31 Mar 2022 20:07:28 +0530 Subject: [PATCH 1/9] Add LightningDataModule.load_from_checkpoint to enable instantiating DataModules directly from checkpoint --- pytorch_lightning/core/datamodule.py | 76 ++++++++- pytorch_lightning/core/saving.py | 161 ++++++++++++++---- .../connectors/checkpoint_connector.py | 11 ++ setup.cfg | 16 +- tests/models/test_hparams.py | 99 ++++++++--- 5 files changed, 289 insertions(+), 74 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 02011fd7e90bf..a02853934938c 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -13,14 +13,16 @@ # limitations under the License. """LightningDataModule for loading DataLoaders with ease.""" from argparse import ArgumentParser, Namespace -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin +from pytorch_lightning.core.saving import _load_from_checkpoint from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.types import _PATH class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -53,6 +55,9 @@ def teardown(self): """ name: str = ... + CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" + CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" + CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None): super().__init__() @@ -262,3 +267,72 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: state_dict: the datamodule state returned by ``state_dict``. """ pass + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: Union[_PATH, IO], + hparams_file: Optional[_PATH] = None, + **kwargs, + ): + r""" + Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint + it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``. + + Any arguments specified through \*\*kwargs will override args stored in ``"datamodule_hyper_parameters"``. + + Args: + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object + hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure + as in this example:: + + dataloader: + batch_size: 32 + + You most likely won't need this since Lightning will always save the hyperparameters + to the checkpoint. + However, if your checkpoint weights don't have the hyperparameters saved, + use this method to pass in a ``.yaml`` file with the hparams you'd like to use. + These will be converted into a :class:`~dict` and passed into your + :class:`LightningDataModule` for use. + + If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace` + and ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat + ``hparams`` as :class:`~dict`. + \**kwargs: Any extra keyword args needed to init the datamodule. 
Can also be used to override saved + hyperparameter values. + + Return: + :class:`LightningDataModule` instance with loaded weights and hyperparameters (if available). + + Note: + ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningDataModule` + **class** to call it instead of the :class:`LightningDataModule` instance. + + Example:: + + # load weights without mapping ... + datamodule = MyLightningDataModule.load_from_checkpoint('path/to/checkpoint.ckpt') + + # or load weights and hyperparameters from separate files. + datamodule = MyLightningDataModule.load_from_checkpoint( + 'path/to/checkpoint.ckpt', + hparams_file='/path/to/hparams_file.yaml' + ) + + # override some of the params with new values + datamodule = MyLightningDataModule.load_from_checkpoint( + PATH, + batch_size=32, + num_workers=10, + ) + + """ + return _load_from_checkpoint( + cls, + checkpoint_path, + map_location=None, + hparams_file=hparams_file, + strict=None, + **kwargs, + ) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index fa0f92eb3b971..0b733bf6d902d 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -26,6 +26,7 @@ import torch import yaml +import pytorch_lightning as pl from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -33,6 +34,7 @@ from pytorch_lightning.utilities.migration import pl_legacy_patch from pytorch_lightning.utilities.parsing import parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.types import _PATH log = logging.getLogger(__name__) PRIMITIVE_TYPES = (bool, int, float, str) @@ -73,7 +75,7 @@ def load_from_checkpoint( If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:`torch.load`. - hparams_file: Optional path to a .yaml file with hierarchical structure + hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure as in this example:: drop_prob: 0.2 @@ -83,16 +85,16 @@ def load_from_checkpoint( You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, - use this method to pass in a .yaml file with the hparams you'd like to use. + use this method to pass in a ``.yaml`` file with the hparams you'd like to use. These will be converted into a :class:`~dict` and passed into your :class:`LightningModule` for use. If your model's ``hparams`` argument is :class:`~argparse.Namespace` - and .yaml file has hierarchical structure, you need to refactor your model to treat + and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat ``hparams`` as :class:`~dict`. strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys returned by this module's state dict. - kwargs: Any extra keyword args needed to init the model. Can also be used to override saved + \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values. 
Return: @@ -132,37 +134,17 @@ def load_from_checkpoint( pretrained_model.freeze() y_hat = pretrained_model(x) """ - with pl_legacy_patch(): - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - if hparams_file is not None: - extension = hparams_file.split(".")[-1] - if extension.lower() == "csv": - hparams = load_hparams_from_tags_csv(hparams_file) - elif extension.lower() in ("yml", "yaml"): - hparams = load_hparams_from_yaml(hparams_file) - else: - raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") - - hparams["on_gpu"] = False - - # overwrite hparams by the given file - checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams - - # for past checkpoint need to add the new key - if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: - checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} - # override the hparams with values that were passed in - checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) - - model = cls._load_model_state(checkpoint, strict=strict, **kwargs) - return model + return _load_from_checkpoint( + cls, + checkpoint_path, + map_location, + hparams_file, + strict, + **kwargs, + ) @classmethod - def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): + def _load_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() @@ -247,6 +229,111 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: """ +def _load_from_checkpoint( + cls: Union["pl.LightningModule", "pl.LightningDataModule"], + checkpoint_path: Union[str, IO], + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: Optional[bool] = None, + **kwargs, +): + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + hparams["on_gpu"] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) + + if issubclass(cls, pl.LightningDataModule): + return _load_state(cls, checkpoint, **kwargs) + return _load_state(cls, checkpoint, strict=strict, **kwargs) + + +def _load_state( + cls: Union["pl.LightningModule", "pl.LightningDataModule"], + checkpoint: Dict[str, Any], + strict: Optional[bool] = None, + **cls_kwargs_new, +): + cls_spec = inspect.getfullargspec(cls.__init__) + cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() + + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + drop_names = [n for n in (self_var, args_var, kwargs_var) if n] + cls_init_args_name = 
list(filter(lambda n: n not in drop_names, cls_init_args_name)) + + cls_kwargs_loaded = {} + # pass in the values we saved automatically + if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: + + if issubclass(cls, pl.LightningModule): + # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys + for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: + cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) + + # 2. Try to restore model hparams from checkpoint using the new key + _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY + cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) + + # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace + cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE)) + + # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority + args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) + if args_name and args_name in cls_init_args_name: + cls_kwargs_loaded = {args_name: cls_kwargs_loaded} + + _cls_kwargs = {} + _cls_kwargs.update(cls_kwargs_loaded) + _cls_kwargs.update(cls_kwargs_new) + + if not cls_spec.varkw: + # filter kwargs according to class init unless it allows any argument via kwargs + _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} + + obj = cls(**_cls_kwargs) + + # give model a chance to load something + obj.on_load_checkpoint(checkpoint) + + if isinstance(obj, pl.LightningDataModule): + return obj + + # load the state_dict on the model automatically + keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict) + + if not strict: + if keys.missing_keys: + rank_zero_warn( + f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}" + ) + if keys.unexpected_keys: + rank_zero_warn( + f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}" + ) + + return obj + + def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define @@ -288,7 +375,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: +def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: """Load hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') @@ -311,7 +398,7 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: return tags -def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: +def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None: fs = get_filesystem(tags_csv) if not fs.isdir(os.path.dirname(tags_csv)): raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") @@ -327,7 +414,7 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]: """Load hparams from a file. 
Args: @@ -360,7 +447,7 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict return hparams -def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None: +def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None: """ Args: config_yaml: path to new YAML file diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fa8e5277cf1c4..b3b372b44753b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -360,6 +360,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: } """ model = self.trainer.lightning_module + datamodule = self.trainer.datamodule checkpoint = { # the epoch and global step are saved for compatibility but they are not relevant for restoration @@ -406,6 +407,16 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: else: checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + if datamodule and datamodule.hparams: + if hasattr(datamodule, "_hparams_name"): + checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_NAME] = datamodule._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(datamodule.hparams, Container): + checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY] = datamodule.hparams + checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(datamodule.hparams) + else: + checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(datamodule.hparams) + # dump stateful datamodule datamodule = self.trainer.datamodule if datamodule is not None: diff --git a/setup.cfg b/setup.cfg index 9f908742c0110..773cbc338c3f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,14 +24,14 @@ addopts = --doctest-modules --color=yes --disable-pytest-warnings -filterwarnings = - # error out on our deprecation warnings - ensures the code and tests are kept up-to-date - error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning - error::FutureWarning - # warnings from deprecated modules on import - # TODO: remove in 1.7 - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators - ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory +# filterwarnings = +# # error out on our deprecation warnings - ensures the code and tests are kept up-to-date +# error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning +# error::FutureWarning +# # warnings from deprecated modules on import +# # TODO: remove in 1.7 +# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators +# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory xfail_strict = true junit_duration_report = call diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 63424c8f47a3a..8c9aa48ed45ef 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.boring_model import BoringDataModule from 
tests.helpers.runif import RunIf if _HYDRA_EXPERIMENTAL_AVAILABLE: @@ -69,73 +70,115 @@ def __init__(self, hparams, *my_args, **my_kwargs): self.save_hyperparameters(hparams) +class SaveHparamsDataModule(BoringDataModule): + """Tests that a model can take an object.""" + + def __init__(self, hparams): + super().__init__() + self.save_hyperparameters(hparams) + + +class SaveHparamsDecoratedDataModule(BoringDataModule): + """Tests that a model can take an object.""" + + @decorate + @decorate + def __init__(self, hparams, *my_args, **my_kwargs): + super().__init__() + self.save_hyperparameters(hparams) + + # ------------------------- # STANDARD TESTS # ------------------------- -def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): +def _run_standard_hparams_test(tmpdir, model, cls, datamodule=None, try_overwrite=False): """Tests for the existence of an arg 'test_arg=14'.""" - hparam_type = type(model.hparams) + obj = datamodule if issubclass(cls, LightningDataModule) else model + + hparam_type = type(obj.hparams) # test proper property assignments - assert model.hparams.test_arg == 14 + assert obj.hparams.test_arg == 14 # verify we can train trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule if issubclass(cls, LightningDataModule) else None) # make sure the raw checkpoint saved the properties raw_checkpoint_path = _raw_checkpoint_path(trainer) raw_checkpoint = torch.load(raw_checkpoint_path) - assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint - assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 + assert cls.CHECKPOINT_HYPER_PARAMS_KEY in raw_checkpoint + assert raw_checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]["test_arg"] == 14 # verify that model loads correctly - model2 = cls.load_from_checkpoint(raw_checkpoint_path) - assert model2.hparams.test_arg == 14 + obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + assert obj2.hparams.test_arg == 14 - assert isinstance(model2.hparams, hparam_type) + assert isinstance(obj2.hparams, hparam_type) if try_overwrite: # verify that we can overwrite the property - model3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) - assert model3.hparams.test_arg == 78 + obj3 = cls.load_from_checkpoint(raw_checkpoint_path, test_arg=78) + assert obj3.hparams.test_arg == 78 return raw_checkpoint_path -@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +@pytest.mark.parametrize( + "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule] +) def test_namespace_hparams(tmpdir, cls): - # init model - model = cls(hparams=Namespace(test_arg=14)) + hparams = Namespace(test_arg=14) + + if issubclass(cls, LightningDataModule): + model = BoringModel() + datamodule = cls(hparams=hparams) + else: + model = cls(hparams=hparams) + datamodule = None # run standard test suite - _run_standard_hparams_test(tmpdir, model, cls) + _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule) -@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +@pytest.mark.parametrize( + "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule] +) def test_dict_hparams(tmpdir, cls): - # init model - model = cls(hparams={"test_arg": 14}) + hparams = {"test_arg": 14} + if issubclass(cls, LightningDataModule): + model = BoringModel() + datamodule = cls(hparams=hparams) + else: 
+ model = cls(hparams=hparams) + datamodule = None # run standard test suite - _run_standard_hparams_test(tmpdir, model, cls) + _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule) @RunIf(omegaconf=True) -@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) +@pytest.mark.parametrize( + "cls", [SaveHparamsModel, SaveHparamsDecoratedModel, SaveHparamsDataModule, SaveHparamsDecoratedDataModule] +) def test_omega_conf_hparams(tmpdir, cls): - # init model conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)])) - model = cls(hparams=conf) - assert isinstance(model.hparams, Container) + if issubclass(cls, LightningDataModule): + model = BoringModel() + obj = datamodule = cls(hparams=conf) + else: + obj = model = cls(hparams=conf) + datamodule = None + + assert isinstance(obj.hparams, Container) # run standard test suite - raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls) - model2 = cls.load_from_checkpoint(raw_checkpoint_path) - assert isinstance(model2.hparams, Container) + raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, cls, datamodule=datamodule) + obj2 = cls.load_from_checkpoint(raw_checkpoint_path) + assert isinstance(obj2.hparams, Container) # config specific tests - assert model2.hparams.test_arg == 14 - assert model2.hparams.mylist[0] == 15.4 + assert obj2.hparams.test_arg == 14 + assert obj2.hparams.mylist[0] == 15.4 def test_explicit_args_hparams(tmpdir): From ddd96c10eecc1b342c37264088e020285391d4eb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 31 Mar 2022 20:09:28 +0530 Subject: [PATCH 2/9] cleanup --- pytorch_lightning/core/saving.py | 59 -------------------------------- setup.cfg | 16 ++++----- 2 files changed, 8 insertions(+), 67 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 0b733bf6d902d..654f3c511999e 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -143,65 +143,6 @@ def load_from_checkpoint( **kwargs, ) - @classmethod - def _load_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new): - cls_spec = inspect.getfullargspec(cls.__init__) - cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() - - self_var, args_var, kwargs_var = parse_class_init_keys(cls) - drop_names = [n for n in (self_var, args_var, kwargs_var) if n] - cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name)) - - cls_kwargs_loaded = {} - # pass in the values we saved automatically - if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - - # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys - for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: - cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) - - # 2. Try to restore model hparams from checkpoint using the new key - _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY - cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key)) - - # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace - cls_kwargs_loaded = _convert_loaded_hparams( - cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE) - ) - - # 4. 
Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority - args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) - if args_name and args_name in cls_init_args_name: - cls_kwargs_loaded = {args_name: cls_kwargs_loaded} - - _cls_kwargs = {} - _cls_kwargs.update(cls_kwargs_loaded) - _cls_kwargs.update(cls_kwargs_new) - - if not cls_spec.varkw: - # filter kwargs according to class init unless it allows any argument via kwargs - _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name} - - model = cls(**_cls_kwargs) - - # give model a chance to load something - model.on_load_checkpoint(checkpoint) - - # load the state_dict on the model automatically - keys = model.load_state_dict(checkpoint["state_dict"], strict=strict) - - if not strict: - if keys.missing_keys: - rank_zero_warn( - f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}" - ) - if keys.unexpected_keys: - rank_zero_warn( - f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}" - ) - - return model - # ------------------------- # OPTIONAL HOOKS # ------------------------- diff --git a/setup.cfg b/setup.cfg index 773cbc338c3f1..9f908742c0110 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,14 +24,14 @@ addopts = --doctest-modules --color=yes --disable-pytest-warnings -# filterwarnings = -# # error out on our deprecation warnings - ensures the code and tests are kept up-to-date -# error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning -# error::FutureWarning -# # warnings from deprecated modules on import -# # TODO: remove in 1.7 -# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators -# ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory +filterwarnings = + # error out on our deprecation warnings - ensures the code and tests are kept up-to-date + error::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning + error::FutureWarning + # warnings from deprecated modules on import + # TODO: remove in 1.7 + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.decorators + ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning:pytorch_lightning.core.memory xfail_strict = true junit_duration_report = call From d1ff482b6d69b10aa2940686a6e6d3a72fdc0bba Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 31 Mar 2022 20:12:34 +0530 Subject: [PATCH 3/9] chlog --- CHANGELOG.md | 2 +- docs/source/common/checkpointing.rst | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d086a82e0ba5..ddc957b99c215 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added -- +- Add `LightningDataModule.load_from_checkpoint` to support loading datamodules directly from checkpoint ([#12550](https://github.com/PyTorchLightning/pytorch-lightning/pull/12550)) - diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst index 31824e828cc7d..2ca16574f2cc0 100644 --- a/docs/source/common/checkpointing.rst +++ b/docs/source/common/checkpointing.rst @@ -31,6 +31,7 @@ A Lightning checkpoint has everything needed to restore a training session inclu - State of all callbacks (for stateful callbacks) - State of datamodule (for stateful datamodules) - The hyperparameters used for that model if passed in as hparams (Argparse.Namespace) +- The hyperparameters used for that datamodule if passed in as hparams (Argparse.Namespace) - State of Loops (if using Fault-Tolerant training) From 78802b8e3a04644eaf6fa7446bafe41aa63cd42f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 4 Apr 2022 18:24:27 +0530 Subject: [PATCH 4/9] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/core/saving.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 654f3c511999e..da81e4c212560 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -176,13 +176,12 @@ def _load_from_checkpoint( map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, strict: Optional[bool] = None, - **kwargs, -): + **kwargs: Any, +) -> Any: + if map_location is None: + map_location = lambda storage, loc: storage with pl_legacy_patch(): - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + checkpoint = pl_load(checkpoint_path, map_location=map_location) if hparams_file is not None: extension = hparams_file.split(".")[-1] @@ -193,14 +192,11 @@ def _load_from_checkpoint( else: raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") - hparams["on_gpu"] = False - # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams # for past checkpoint need to add the new key - if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: - checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}) # override the hparams with values that were passed in checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs) @@ -213,8 +209,8 @@ def _load_state( cls: Union["pl.LightningModule", "pl.LightningDataModule"], checkpoint: Dict[str, Any], strict: Optional[bool] = None, - **cls_kwargs_new, -): + **cls_kwargs_new: Any, +) -> Any: cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() From 3a632d15bcfca10f094ec2114367c1e5ba3d1a4a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 4 Apr 2022 18:29:58 +0530 Subject: [PATCH 5/9] use instance --- .../trainer/connectors/checkpoint_connector.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b3b372b44753b..212ccd081faeb 100644 --- 
a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -399,23 +399,23 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump hyper-parameters if model.hparams: if hasattr(model, "_hparams_name"): - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name + checkpoint[model.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name # dump arguments if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + checkpoint[model.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams + checkpoint[model.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) else: - checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) + checkpoint[model.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) if datamodule and datamodule.hparams: if hasattr(datamodule, "_hparams_name"): - checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_NAME] = datamodule._hparams_name + checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_NAME] = datamodule._hparams_name # dump arguments if _OMEGACONF_AVAILABLE and isinstance(datamodule.hparams, Container): - checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY] = datamodule.hparams - checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(datamodule.hparams) + checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_KEY] = datamodule.hparams + checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(datamodule.hparams) else: - checkpoint[pl.LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(datamodule.hparams) + checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(datamodule.hparams) # dump stateful datamodule datamodule = self.trainer.datamodule From bee09ba26ce471ff32ae86529cf1e6a097050630 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 4 Apr 2022 18:32:20 +0530 Subject: [PATCH 6/9] avoid dup --- .../connectors/checkpoint_connector.py | 30 +++++++------------ 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 212ccd081faeb..8f91e9d4e29ee 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -397,25 +397,17 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: prec_plugin.on_save_checkpoint(checkpoint) # dump hyper-parameters - if model.hparams: - if hasattr(model, "_hparams_name"): - checkpoint[model.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name - # dump arguments - if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): - checkpoint[model.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - checkpoint[model.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) - else: - checkpoint[model.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) - - if datamodule and datamodule.hparams: - if hasattr(datamodule, "_hparams_name"): - checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_NAME] = datamodule._hparams_name - # dump arguments - if _OMEGACONF_AVAILABLE and isinstance(datamodule.hparams, Container): - checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_KEY] = datamodule.hparams - checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(datamodule.hparams) - else: - checkpoint[datamodule.CHECKPOINT_HYPER_PARAMS_KEY] = 
dict(datamodule.hparams) + for obj in (model, datamodule): + if obj and obj.hparams: + if obj.hparams: + if hasattr(obj, "_hparams_name"): + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container): + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams) + else: + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams) # dump stateful datamodule datamodule = self.trainer.datamodule From 1a23a465d7422e911288400fd88f4322330e1138 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 4 Apr 2022 15:28:55 +0200 Subject: [PATCH 7/9] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddc957b99c215..c04c74f2f6d63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Add `LightningDataModule.load_from_checkpoint` to support loading datamodules directly from checkpoint ([#12550](https://github.com/PyTorchLightning/pytorch-lightning/pull/12550)) +- Added `LightningDataModule.load_from_checkpoint` to support loading datamodules directly from checkpoint ([#12550](https://github.com/PyTorchLightning/pytorch-lightning/pull/12550)) - From 96bc07b1cf49feaaff6377bdab5f103f84e64f0c Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 5 Apr 2022 00:13:00 +0530 Subject: [PATCH 8/9] redundant if --- .../trainer/connectors/checkpoint_connector.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 8f91e9d4e29ee..6c5d75a2e41f1 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -399,15 +399,14 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump hyper-parameters for obj in (model, datamodule): if obj and obj.hparams: - if obj.hparams: - if hasattr(obj, "_hparams_name"): - checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name - # dump arguments - if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container): - checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams - checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams) - else: - checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams) + if hasattr(obj, "_hparams_name"): + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name + # dump arguments + if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container): + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams) + else: + checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams) # dump stateful datamodule datamodule = self.trainer.datamodule From 3db0775f4c6ef4d183fcb3d8a15c9bf40cd3b6e0 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 26 Apr 2022 16:03:01 +0530 Subject: [PATCH 9/9] docs --- docs/source/common/checkpointing_basic.rst | 1 + pytorch_lightning/core/datamodule.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/checkpointing_basic.rst b/docs/source/common/checkpointing_basic.rst index 
899de91a1b0ed..6ff54c94245d2 100644 --- a/docs/source/common/checkpointing_basic.rst +++ b/docs/source/common/checkpointing_basic.rst @@ -36,6 +36,7 @@ Inside a Lightning checkpoint you'll find: - State of all callbacks (for stateful callbacks) - State of datamodule (for stateful datamodules) - The hyperparameters used for that model if passed in as hparams (Argparse.Namespace) +- The hyperparameters used for that datamodule if passed in as hparams (Argparse.Namespace) - State of Loops (if using Fault-Tolerant training) ---- diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 7e99b3a705a8a..cddbddefb02bd 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -20,7 +20,6 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import _load_from_checkpoint -from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types from pytorch_lightning.utilities.types import _PATH
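
End-to-end usage sketch (an addendum, not part of the patches above): the snippet below illustrates the behaviour this series adds. Datamodule hyperparameters captured via ``save_hyperparameters()`` are written to the checkpoint under ``"datamodule_hyper_parameters"``, and the datamodule can then be re-instantiated with ``LightningDataModule.load_from_checkpoint``, with keyword arguments taking priority over the saved values. It assumes a ``pytorch_lightning`` build that already includes these patches; the class names, tensor shapes, and checkpoint path are illustrative only and do not appear in the PR.

Example::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class RandomDataModule(pl.LightningDataModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            # save_hyperparameters() records `batch_size`, so the Trainer can dump it
            # into the checkpoint under the new "datamodule_hyper_parameters" key.
            self.save_hyperparameters()

        def train_dataloader(self) -> DataLoader:
            return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=self.hparams.batch_size)


    class TinyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).pow(2).mean()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False)
        trainer.fit(TinyModel(), datamodule=RandomDataModule(batch_size=16))
        trainer.save_checkpoint("dm_example.ckpt")

        # The datamodule hyperparameters are stored next to the model ones.
        ckpt = torch.load("dm_example.ckpt")
        assert ckpt["datamodule_hyper_parameters"]["batch_size"] == 16

        # Re-instantiate the datamodule directly from the checkpoint ...
        dm = RandomDataModule.load_from_checkpoint("dm_example.ckpt")
        assert dm.hparams.batch_size == 16

        # ... or override any saved hyperparameter via kwargs, which take priority.
        dm_override = RandomDataModule.load_from_checkpoint("dm_example.ckpt", batch_size=8)
        assert dm_override.hparams.batch_size == 8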