diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 97bb317b02a14..04db3d1908bb2 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061)) +- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151)) + + + ## [1.7.1] - 2022-08-09 ### Fixed diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py index 81877f1dffba7..073423ab60773 100644 --- a/src/pytorch_lightning/utilities/parsing.py +++ b/src/pytorch_lightning/utilities/parsing.py @@ -162,7 +162,10 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]: def collect_init_args( - frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False + frame: types.FrameType, + path_args: List[Dict[str, Any]], + inside: bool = False, + classes: Tuple[Type, ...] = (), ) -> List[Dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -170,6 +173,7 @@ def collect_init_args( frame: the current stack frame path_args: a list of dictionaries containing the constructor args in all parent classes inside: track if we are inside inheritance path, avoid terminating too soon + classes: the classes in which to inspect the frames Return: A list of dictionaries where each dictionary contains the arguments passed to the @@ -181,13 +185,13 @@ def collect_init_args( if not isinstance(frame.f_back, types.FrameType): return path_args - if "__class__" in local_vars: + if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)): local_args = get_init_args(frame) # recursive update path_args.append(local_args) - return collect_init_args(frame.f_back, path_args, inside=True) + return collect_init_args(frame.f_back, path_args, inside=True, classes=classes) if not inside: - return collect_init_args(frame.f_back, path_args, inside) + return collect_init_args(frame.f_back, path_args, inside, classes=classes) return path_args @@ -225,7 +229,10 @@ def save_hyperparameters( init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} else: init_args = {} - for local_args in collect_init_args(frame, []): + + from pytorch_lightning.core.mixins import HyperparametersMixin + + for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)): init_args.update(local_args) if ignore is None: diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index c130381c7832d..84311d6f780fb 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -29,6 +29,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable @@ -399,6 +400,24 @@ def _raw_checkpoint_path(trainer) -> str: return raw_checkpoint_path +@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule)) +def test_save_hyperparameters_under_composition(base_class): + """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get + collected.""" + + class ChildInComposition(base_class): + def __init__(self, same_arg): + super().__init__() + self.save_hyperparameters() + + class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule + def __init__(self, same_arg="parent_default", other_arg="other"): + self.child = ChildInComposition(same_arg="cocofruit") + + parent = NotPLSubclass() + assert parent.child.hparams == dict(same_arg="cocofruit") + + class LocalVariableModelSuperLast(BoringModel): """This model has the super().__init__() call at the end."""