diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e06ccc6abe1b..6b879a6dfdb6d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))
 - Fixed docs typo ([#4659](https://github.com/PyTorchLightning/pytorch-lightning/pull/4659), [#4670](https://github.com/PyTorchLightning/pytorch-lightning/pull/4670))
 - Fixed notebooks typo ([#4657](https://github.com/PyTorchLightning/pytorch-lightning/pull/4657))
+- Allowed decorating model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662))
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 2662aa6758332..bd76366121a63 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -17,7 +17,7 @@
 import inspect
 import os
 from argparse import Namespace
-from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
+from typing import IO, Any, Callable, Dict, MutableMapping, Optional, Union
 from warnings import warn
 
 import fsspec
@@ -25,10 +25,10 @@
 import yaml
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
-from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities import AttributeDict, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.parsing import parse_class_init_keys
 
 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -159,8 +159,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
     cls_spec = inspect.getfullargspec(cls.__init__)
     cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
 
-    self_name = cls_spec.args[0]
-    drop_names = (self_name, cls_spec.varargs, cls_spec.varkw)
+    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
+    drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
     cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))
 
     cls_kwargs_loaded = {}
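The `_load_model_state` change above is the crux of the fix: the old `inspect.getfullargspec(cls.__init__).args[0]` raises an `IndexError` when `__init__` is wrapped by a typical `*args, **kwargs` decorator, because `getfullargspec` ignores the `__wrapped__` attribute that `functools.wraps` records, while `inspect.signature` (used by the `parse_class_init_keys` helper introduced below) follows it. A minimal sketch reproducing the difference; the `decorate` helper and `DecoratedModel` class here are illustrative, not part of the patch:

```python
import functools
import inspect


def decorate(func):
    @functools.wraps(func)  # copies metadata and records the original under __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class DecoratedModel:
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        pass


# getfullargspec ignores __wrapped__, so it only sees the wrapper's (*args, **kwargs)
spec = inspect.getfullargspec(DecoratedModel.__init__)
print(spec.args)  # [] -- the old `spec.args[0]` raises IndexError here

# signature() follows the __wrapped__ chain and recovers the real parameters
print(list(inspect.signature(DecoratedModel.__init__).parameters))
# ['self', 'hparams', 'my_args', 'my_kwargs']
```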
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index 348eec110c3a1..521dd5200521a 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -15,7 +15,7 @@
 import inspect
 import pickle
 from argparse import Namespace
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
 
 from pytorch_lightning.utilities import rank_zero_warn
 
@@ -79,23 +79,46 @@ def clean_namespace(hparams):
             del hparams_dict[k]
 
 
+def parse_class_init_keys(cls) -> Tuple[str, str, str]:
+    """Parse the names of the standard ``self``, ``*args`` and ``**kwargs`` parameters of a class ``__init__``
+
+    >>> class Model():
+    ...     def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
+    ...         pass
+    >>> parse_class_init_keys(Model)
+    ('self', 'my_args', 'my_kwargs')
+    """
+    init_parameters = inspect.signature(cls.__init__).parameters
+    # the docs claim the params are always ordered
+    # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
+    init_params = list(init_parameters.values())
+    # self is always first
+    n_self = init_params[0].name
+
+    def _get_first_if_any(params, param_type):
+        for p in params:
+            if p.kind == param_type:
+                return p.name
+
+    n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
+    n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
+
+    return n_self, n_args, n_kwargs
+
+
 def get_init_args(frame) -> dict:
     _, _, _, local_vars = inspect.getargvalues(frame)
     if '__class__' not in local_vars:
-        return
+        return {}
     cls = local_vars['__class__']
-    spec = inspect.getfullargspec(cls.__init__)
     init_parameters = inspect.signature(cls.__init__).parameters
-    self_identifier = spec.args[0]  # "self" unless user renames it (always first arg)
-    varargs_identifier = spec.varargs  # by convention this is named "*args"
-    kwargs_identifier = spec.varkw  # by convention this is named "**kwargs"
-    exclude_argnames = (
-        varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
-    )
+    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
+    filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
+    exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
     # only collect variables that appear in the signature
     local_args = {k: local_vars[k] for k in init_parameters.keys()}
-    local_args.update(local_args.get(kwargs_identifier, {}))
+    local_args.update(local_args.get(kwargs_var, {}))
     local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
     return local_args
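As a quick check that the new helper copes with the doubly-decorated `__init__` pattern the tests below introduce, here is a hedged sketch; it assumes the patched `pytorch_lightning.utilities.parsing` is importable, and `Model` is an illustrative class, not part of the patch:

```python
import functools

from pytorch_lightning.utilities.parsing import parse_class_init_keys


def decorate(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class Model:
    @decorate
    @decorate  # signature() unwinds the whole __wrapped__ chain, one level per decorator
    def __init__(self, hparams, *my_args, **my_kwargs):
        pass


# the real self/*args/**kwargs names are recovered despite the decorators
assert parse_class_init_keys(Model) == ('self', 'my_args', 'my_kwargs')
```

Note the contract: when a signature has no `*args` or `**kwargs`, the corresponding tuple entries come back as `None`, which is why both call sites (`_load_model_state` and `get_init_args`) filter the result with `if n`.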
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index b7d0be01e9622..908a1ea757a15 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 import os
 import pickle
 from argparse import Namespace
@@ -19,30 +20,56 @@
 import pytest
 import torch
 from fsspec.implementations.local import LocalFileSystem
-from omegaconf import OmegaConf, Container
+from omegaconf import Container, OmegaConf
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 
-from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
 from pytorch_lightning.utilities import AttributeDict, is_picklable
-from tests.base import EvalModelTemplate, TrialMNIST, BoringModel
+from tests.base import BoringModel, EvalModelTemplate, TrialMNIST
 
 
-class SaveHparamsModel(EvalModelTemplate):
+class SaveHparamsModel(BoringModel):
     """ Tests that a model can take an object """
     def __init__(self, hparams):
         super().__init__()
         self.save_hyperparameters(hparams)
 
 
-class AssignHparamsModel(EvalModelTemplate):
+class AssignHparamsModel(BoringModel):
     """ Tests that a model can take an object with explicit setter """
     def __init__(self, hparams):
         super().__init__()
         self.hparams = hparams
 
 
+def decorate(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+class SaveHparamsDecoratedModel(BoringModel):
+    """ Tests that a model can take an object through a decorated init """
+    @decorate
+    @decorate
+    def __init__(self, hparams, *my_args, **my_kwargs):
+        super().__init__()
+        self.save_hyperparameters(hparams)
+
+
+class AssignHparamsDecoratedModel(BoringModel):
+    """ Tests that a model can take an object with explicit setter through a decorated init """
+    @decorate
+    @decorate
+    def __init__(self, hparams, *my_args, **my_kwargs):
+        super().__init__()
+        self.hparams = hparams
+
+
 # -------------------------
 # STANDARD TESTS
 # -------------------------
@@ -78,7 +105,9 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
     return raw_checkpoint_path
 
 
-@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
+@pytest.mark.parametrize("cls", [
+    SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
+])
 def test_namespace_hparams(tmpdir, cls):
     # init model
     model = cls(hparams=Namespace(test_arg=14))
@@ -87,7 +116,9 @@ def test_namespace_hparams(tmpdir, cls):
     _run_standard_hparams_test(tmpdir, model, cls)
 
 
-@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
+@pytest.mark.parametrize("cls", [
+    SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
+])
 def test_dict_hparams(tmpdir, cls):
     # init model
     model = cls(hparams={'test_arg': 14})
@@ -96,7 +127,9 @@ def test_dict_hparams(tmpdir, cls):
     _run_standard_hparams_test(tmpdir, model, cls)
 
 
-@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
+@pytest.mark.parametrize("cls", [
+    SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
+])
 def test_omega_conf_hparams(tmpdir, cls):
     # init model
     conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
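Taken together, the new parametrizations run the standard hparams round trip over the decorated models as well. A condensed sketch of what each parametrized case exercises; the paths and `Trainer` arguments are illustrative, and `SaveHparamsDecoratedModel` is the test class defined above (it relies on `BoringModel` providing its own dataloaders):

```python
import os
from argparse import Namespace

from pytorch_lightning import Trainer

# hypothetical stand-in for the pytest `tmpdir` fixture
tmpdir = "/tmp/hparams_roundtrip"
os.makedirs(tmpdir, exist_ok=True)

# hparams are captured despite the doubly decorated __init__
model = SaveHparamsDecoratedModel(hparams=Namespace(test_arg=14))
assert model.hparams.test_arg == 14

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2)
trainer.fit(model)
ckpt_path = os.path.join(tmpdir, "decorated.ckpt")
trainer.save_checkpoint(ckpt_path)

# the saved hyperparameters are restored on load
model2 = SaveHparamsDecoratedModel.load_from_checkpoint(ckpt_path)
assert model2.hparams.test_arg == 14
```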