Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 4, 2020
1 parent 0662707 commit c7e0b74
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,10 +1756,11 @@ def module_arguments(self) -> dict:
custom object or dict in which the keys are the union of all argument names in the constructor
and all parent constructors, excluding `self`, `*args` and `**kwargs`.
"""
args = copy.deepcopy(self._module_parents_arguments)
if isinstance(args, dict):
if isinstance(self._module_self_arguments, dict):
args = copy.deepcopy(self._module_parents_arguments)
args.update(self._module_self_arguments)
return args
return args
return copy.deepcopy(self._module_self_arguments)

def save_hyperparameters(self, *args, **kwargs) -> None:
"""
Expand Down Expand Up @@ -1787,6 +1788,18 @@ def save_hyperparameters(self, *args, **kwargs) -> None:
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> OrderedDict(model.module_arguments)
OrderedDict([('arg1', 1), ('arg2', 'abc'), ('arg3', 3.14)])
>>> from collections import OrderedDict
>>> class SingleArgModel(LightningModule):
... def __init__(self, hparams):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(hparams)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.module_arguments
Namespace(p1=1, p2='abc', p3=3.14)
"""
if not args and not kwargs:
self._auto_collect_arguments()
Expand Down

0 comments on commit c7e0b74

Please sign in to comment.