Skip to content

Commit

Permalink
Fix local variables being collected into module_arguments dict (#2048)
Browse files Browse the repository at this point in the history
* do not include local vars in auto collection

* add test

* add test for model with "self" renamed to "obj"

* skip decorator

* changelog

* changelog

* update docs

* remove obsolete child collection

* generalize **args, **kwargs names

* docs

* also update varargs passed in

* Revert "also update varargs passed in"

This reverts commit 3d7a30d.

* update test
  • Loading branch information
awaelchli authored Jun 4, 2020
1 parent fd7814d commit 4234992
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 36 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow use of same `WandbLogger` instance for multiple training loops ([#2055](https://github.com/PyTorchLightning/pytorch-lightning/pull/2055))

- Fixed an issue where local variables were being collected into module_arguments ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

- Fixed an issue with `auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

## [0.7.6] - 2020-05-16

### Added
Expand Down
73 changes: 48 additions & 25 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,22 +1698,34 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
" and this method will be removed in v1.0.0", DeprecationWarning)
return self.get_progress_bar_dict()

def auto_collect_arguments(self):
"""Collect all arguments module arguments."""
def auto_collect_arguments(self) -> None:
"""
Collect all module arguments in the current constructor and all child constructors.
The child constructors are all the ``__init__`` methods that reach the current class through
(chained) ``super().__init__()`` calls.
"""
frame = inspect.currentframe()

frame_args = _collect_init_args(frame.f_back, [])
child = _get_latest_child(frame)
self_arguments = frame_args[-1]

# set module_arguments in child
child._module_self_arguments = frame_args[-1]
child._module_parents_arguments = {}
self._module_self_arguments = self_arguments
self._module_parents_arguments = {}

# add all arguments from parents
for args in frame_args[:-1]:
child._module_parents_arguments.update(args)
self._module_parents_arguments.update(args)

@property
def module_arguments(self) -> dict:
"""Aggregate this module and all parents arguments."""
"""
Aggregate of arguments passed to the constructor of this module and all parents.
Return:
a dict in which the keys are the union of all argument names in the constructor and all
parent constructors, excluding `self`, `*args` and `**kwargs`.
"""
try:
args = dict(self._module_parents_arguments)
args.update(self._module_self_arguments)
Expand All @@ -1724,26 +1736,37 @@ def module_arguments(self) -> dict:


def _collect_init_args(frame, path_args: list) -> list:
"""Recursive search for all children."""
if '__class__' in frame.f_locals:
local_args = dict(frame.f_locals)
local_args.update(local_args.get('kwargs', {}))
local_args = {k: v for k, v in local_args.items()
if k not in ('args', 'kwargs', 'self', '__class__', 'frame', 'frame_args')}
# if 'hparams' in local_args:
# # back compatible hparams as single argument
# hparams = local_args.get('hparams')
# local_args.update(vars(hparams) if isinstance(hparams, Namespace) else hparams)
"""
Recursively collects the arguments passed to the child constructors in the inheritance tree.
Args:
frame: the current stack frame
path_args: a list of dictionaries containing the constructor args in all parent classes
Return:
A list of dictionaries where each dictionary contains the arguments passed to the
constructor at that level. The last entry corresponds to the constructor call of the
most specific class in the hierarchy.
"""
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' in local_vars:
cls = local_vars['__class__']
spec = inspect.getfullargspec(cls.__init__)
init_parameters = inspect.signature(cls.__init__).parameters
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
exclude_argnames = (
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
)

# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_identifier, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}

# recursive update
path_args.append(local_args)
return _collect_init_args(frame.f_back, path_args)
else:
return path_args


def _get_latest_child(frame, child: object = None) -> object:
"""Recursive search for lowest child."""
if 'self' in frame.f_locals:
return _get_latest_child(frame.f_back, frame.f_locals['self'])
else:
return child
63 changes: 52 additions & 11 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ def test2(self):
A().test()


@pytest.mark.skipif(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):

# ogc only for 3.8
major = sys.version_info[0]
minor = sys.version_info[1]
if major < 3 and minor < 8:
return

conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
model = OmegaConfModel(conf)

Expand All @@ -73,6 +67,17 @@ def __init__(self, *args, subclass_arg=1200, **kwargs):
self.auto_collect_arguments()


class UnconventionalArgsEvalModel(EvalModelTemplate):
""" A model that has unconventional names for "self", "*args" and "**kwargs". """

def __init__(obj, *more_args, other_arg=300, **more_kwargs):
# intentionally named obj
super().__init__(*more_args, **more_kwargs)
obj.other_arg = other_arg
other_arg = 321
obj.auto_collect_arguments()


class SubSubClassEvalModel(SubClassEvalModel):
pass

Expand All @@ -85,10 +90,13 @@ def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
self.auto_collect_arguments()


@pytest.mark.parametrize("cls", [EvalModelTemplate,
SubClassEvalModel,
SubSubClassEvalModel,
AggSubClassEvalModel])
@pytest.mark.parametrize("cls", [
EvalModelTemplate,
SubClassEvalModel,
SubSubClassEvalModel,
AggSubClassEvalModel,
UnconventionalArgsEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
""" Test that the model automatically saves the arguments passed into the constructor """
extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {}
Expand Down Expand Up @@ -125,3 +133,36 @@ def test_collect_init_arguments(tmpdir, cls):
# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
assert model.batch_size == 99


class LocalVariableModel1(EvalModelTemplate):
""" This model has the super().__init__() call at the end. """

def __init__(self, arg1, arg2, *args, **kwargs):
self.argument1 = arg1 # arg2 intentionally not set
arg1 = 'overwritten'
local_var = 1234
super().__init__(*args, **kwargs) # this is intentionally here at the end


class LocalVariableModel2(EvalModelTemplate):
""" This model has the auto_collect_arguments() call at the end. """

def __init__(self, arg1, arg2, *args, **kwargs):
super().__init__(*args, **kwargs)
self.argument1 = arg1 # arg2 intentionally not set
arg1 = 'overwritten'
local_var = 1234
self.auto_collect_arguments() # this is intentionally here at the end


@pytest.mark.parametrize("cls", [
LocalVariableModel1,
LocalVariableModel2,
])
def test_collect_init_arguments_with_local_vars(cls):
""" Tests that only the arguments are collected and not local variables. """
model = cls(arg1=1, arg2=2)
assert 'local_var' not in model.module_arguments
assert model.module_arguments['arg1'] == 'overwritten'
assert model.module_arguments['arg2'] == 2

0 comments on commit 4234992

Please sign in to comment.