allow decorating model init with saving hparams (#4662)
* add tests

* use boring model

* parsing init

* chlog

* double decorate

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* bug

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
5 people committed Nov 17, 2020
1 parent 1490075 commit 52a0781
Showing 4 changed files with 82 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))
- Fixed docs typo ([#4659](https://github.com/PyTorchLightning/pytorch-lightning/pull/4659), [#4670](https://github.com/PyTorchLightning/pytorch-lightning/pull/4670))
- Fixed notebooks typo ([#4657](https://github.com/PyTorchLightning/pytorch-lightning/pull/4657))
- Allowed decorating model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662))



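For reference, a minimal sketch of the pattern this changelog entry enables, mirroring the `SaveHparamsDecoratedModel` test added below: an `__init__` wrapped by pass-through decorators (each using `functools.wraps`) whose `self.save_hyperparameters(hparams)` call still records the hyperparameters. The `decorate`/`DecoratedModel` names are illustrative, and the snippet assumes a pytorch-lightning release that includes this commit.

import functools
from argparse import Namespace

from pytorch_lightning import LightningModule


def decorate(func):
    # pass-through decorator; functools.wraps keeps the original __init__ visible to inspect
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class DecoratedModel(LightningModule):
    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        super().__init__()
        # hyperparameters are collected correctly even through the double decoration
        self.save_hyperparameters(hparams)


model = DecoratedModel(hparams=Namespace(test_arg=14))
assert model.hparams.test_arg == 14
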
12 changes: 6 additions & 6 deletions pytorch_lightning/core/saving.py
@@ -17,18 +17,18 @@
import inspect
import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
from typing import IO, Any, Callable, Dict, MutableMapping, Optional, Union
from warnings import warn

import fsspec
import torch
import yaml

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities import AttributeDict, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem

from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.parsing import parse_class_init_keys

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -159,8 +159,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

self_name = cls_spec.args[0]
drop_names = (self_name, cls_spec.varargs, cls_spec.varkw)
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

cls_kwargs_loaded = {}
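The change above swaps the old `getfullargspec`-based lookup for the new `parse_class_init_keys` helper when deciding which checkpoint keys correspond to real `__init__` arguments. A standalone sketch of that filtering; the `Model` class and its `learning_rate` argument are purely illustrative, and the import assumes a version containing this commit.

import inspect

from pytorch_lightning.utilities.parsing import parse_class_init_keys


class Model:
    def __init__(self, hparams, *my_args, learning_rate=0.1, **my_kwargs):
        pass


# all init parameter names, in declaration order
init_args_name = list(inspect.signature(Model.__init__).parameters.keys())
# ['self', 'hparams', 'my_args', 'learning_rate', 'my_kwargs']

# drop self and the variadic collectors, keeping only nameable arguments
self_var, args_var, kwargs_var = parse_class_init_keys(Model)  # ('self', 'my_args', 'my_kwargs')
drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
assert [n for n in init_args_name if n not in drop_names] == ['hparams', 'learning_rate']
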
43 changes: 33 additions & 10 deletions pytorch_lightning/utilities/parsing.py
@@ -15,7 +15,7 @@
import inspect
import pickle
from argparse import Namespace
from typing import Dict, Union
from typing import Dict, Tuple, Union

from pytorch_lightning.utilities import rank_zero_warn

@@ -79,23 +79,46 @@ def clean_namespace(hparams):
del hparams_dict[k]


def parse_class_init_keys(cls) -> Tuple[str, str, str]:
"""Parse key words for standard self, *args and **kwargs
>>> class Model():
... def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
... pass
>>> parse_class_init_keys(Model)
('self', 'my_args', 'my_kwargs')
"""
init_parameters = inspect.signature(cls.__init__).parameters
# docs claims the params are always ordered
# https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
init_params = list(init_parameters.values())
# self is always first
n_self = init_params[0].name

def _get_first_if_any(params, param_type):
for p in params:
if p.kind == param_type:
return p.name

n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)

return n_self, n_args, n_kwargs


def get_init_args(frame) -> dict:
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' not in local_vars:
return
return {}
cls = local_vars['__class__']
spec = inspect.getfullargspec(cls.__init__)
init_parameters = inspect.signature(cls.__init__).parameters
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
exclude_argnames = (
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
)
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')

# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_identifier, {}))
local_args.update(local_args.get(kwargs_var, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
return local_args

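Why `parse_class_init_keys` is built on `inspect.signature` rather than the previously used `inspect.getfullargspec`: per the Python docs, `getfullargspec` ignores `__wrapped__`, so on a `functools.wraps`-decorated `__init__` it only sees the wrapper's `*args, **kwargs`, whereas `signature` follows `__wrapped__` back to the original parameters. A short standalone demonstration of that difference, with no pytorch-lightning imports and only plain CPython behaviour assumed.

import functools
import inspect


def decorate(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


class Model:
    @decorate
    @decorate
    def __init__(self, hparams, *my_args, **my_kwargs):
        pass


# getfullargspec stops at the wrapper, losing the real argument names
spec = inspect.getfullargspec(Model.__init__)
assert spec.args == [] and spec.varargs == 'args' and spec.varkw == 'kwargs'

# inspect.signature follows __wrapped__ and recovers the original parameters
assert list(inspect.signature(Model.__init__).parameters) == ['self', 'hparams', 'my_args', 'my_kwargs']
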
51 changes: 42 additions & 9 deletions tests/models/test_hparams.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import pickle
from argparse import Namespace
@@ -19,30 +19,56 @@
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from omegaconf import OmegaConf, Container
from omegaconf import Container, OmegaConf
from torch.nn import functional as F
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
from tests.base import EvalModelTemplate, TrialMNIST, BoringModel
from tests.base import BoringModel, EvalModelTemplate, TrialMNIST


class SaveHparamsModel(EvalModelTemplate):
class SaveHparamsModel(BoringModel):
""" Tests that a model can take an object """
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)


class AssignHparamsModel(EvalModelTemplate):
class AssignHparamsModel(BoringModel):
""" Tests that a model can take an object with explicit setter """
def __init__(self, hparams):
super().__init__()
self.hparams = hparams


def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


class SaveHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object """
@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.save_hyperparameters(hparams)


class AssignHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object with explicit setter"""
@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.hparams = hparams


# -------------------------
# STANDARD TESTS
# -------------------------
@@ -78,7 +105,9 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
return raw_checkpoint_path


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_namespace_hparams(tmpdir, cls):
# init model
model = cls(hparams=Namespace(test_arg=14))
@@ -87,7 +116,9 @@ def test_namespace_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_dict_hparams(tmpdir, cls):
# init model
model = cls(hparams={'test_arg': 14})
@@ -96,7 +127,9 @@ def test_dict_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_omega_conf_hparams(tmpdir, cls):
# init model
conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))

