From ba0d84f14105e7003dda6fd80228b19ac7bfe365 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 10 Aug 2022 19:59:51 +0200
Subject: [PATCH 1/6] bugfix

---
 src/pytorch_lightning/utilities/parsing.py | 14 +++++++++-----
 tests/tests_pytorch/models/test_hparams.py | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index 81877f1dffba7..a30b167331846 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -162,7 +162,10 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
 
 
 def collect_init_args(
-    frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
+    frame: types.FrameType,
+    path_args: List[Dict[str, Any]],
+    inside: bool = False,
+    classes: Tuple[Type, ...] = (object,),
 ) -> List[Dict[str, Any]]:
     """Recursively collects the arguments passed to the child constructors in the inheritance tree.
 
@@ -170,6 +173,7 @@ def collect_init_args(
         frame: the current stack frame
         path_args: a list of dictionaries containing the constructor args in all parent classes
        inside: track if we are inside inheritance path, avoid terminating too soon
+        classes: the classes in which to inspect the frames
 
     Return:
          A list of dictionaries where each dictionary contains the arguments passed to the
@@ -181,13 +185,13 @@ def collect_init_args(
     if not isinstance(frame.f_back, types.FrameType):
         return path_args
 
-    if "__class__" in local_vars:
+    if "__class__" in local_vars and issubclass(local_vars["__class__"], classes):
         local_args = get_init_args(frame)
         # recursive update
         path_args.append(local_args)
-        return collect_init_args(frame.f_back, path_args, inside=True)
+        return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
     if not inside:
-        return collect_init_args(frame.f_back, path_args, inside)
+        return collect_init_args(frame.f_back, path_args, inside, classes=classes)
     return path_args
 
 
@@ -225,7 +229,7 @@ def save_hyperparameters(
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}
-        for local_args in collect_init_args(frame, []):
+        for local_args in collect_init_args(frame, [], classes=(pl.LightningModule, pl.LightningDataModule)):
             init_args.update(local_args)
 
     if ignore is None:
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index c130381c7832d..a986cb72fab11 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -399,6 +399,25 @@ def _raw_checkpoint_path(trainer) -> str:
     return raw_checkpoint_path
 
 
+@pytest.mark.parametrize("base_class", (LightningModule, LightningDataModule))
+def test_save_hyperparameters_under_composition(base_class):
+    """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
+    collected."""
+
+    class ChildInComposition(base_class):
+        def __init__(self, same_arg):
+            super().__init__()
+            self.save_hyperparameters()
+
+    class NotPLSubclass:  # intentionally not subclassing LightningModule/LightningDataModule
+        def __init__(self, same_arg="parent_default", other_arg="other"):
+            super().__init__()
+            self.child = ChildInComposition(same_arg="cocofruit")
+
+    parent = NotPLSubclass()
+    assert parent.child.hparams == dict(same_arg="cocofruit")
+
+
 class LocalVariableModelSuperLast(BoringModel):
     """This model has the super().__init__() call at the end."""
 

From 4f473d2d5074a294b5ce428735432c660ee7bc84 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 10 Aug 2022 20:04:09 +0200
Subject: [PATCH 2/6] update changelog

---
 src/pytorch_lightning/CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 97bb317b02a14..277ec5389dc40 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
 
 
+- Fix saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
+
+
+
 ## [1.7.1] - 2022-08-09
 
 ### Fixed

From e5ba505a02d639b5b84dbaa70a9a374e3b8dc8bc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 10 Aug 2022 20:06:34 +0200
Subject: [PATCH 3/6] remove redundant init

---
 tests/tests_pytorch/models/test_hparams.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index a986cb72fab11..1fc5aa75423d5 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -411,7 +411,6 @@ def __init__(self, same_arg):
 
     class NotPLSubclass:  # intentionally not subclassing LightningModule/LightningDataModule
         def __init__(self, same_arg="parent_default", other_arg="other"):
-            super().__init__()
             self.child = ChildInComposition(same_arg="cocofruit")
 
     parent = NotPLSubclass()

From 162bc928f65a6474b12f5b66da9f8ce3d77a7b5c Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 10 Aug 2022 20:21:51 +0200
Subject: [PATCH 4/6] update

---
 src/pytorch_lightning/utilities/parsing.py | 5 ++++-
 tests/tests_pytorch/models/test_hparams.py | 3 ++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index a30b167331846..c4a73ec571108 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -229,7 +229,10 @@ def save_hyperparameters(
         init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
     else:
         init_args = {}
-        for local_args in collect_init_args(frame, [], classes=(pl.LightningModule, pl.LightningDataModule)):
+
+        from pytorch_lightning.core.mixins import HyperparametersMixin
+
+        for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
             init_args.update(local_args)
 
     if ignore is None:
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index 1fc5aa75423d5..84311d6f780fb 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -29,6 +29,7 @@
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
@@ -399,7 +400,7 @@ def _raw_checkpoint_path(trainer) -> str:
     return raw_checkpoint_path
 
 
-@pytest.mark.parametrize("base_class", (LightningModule, LightningDataModule))
+@pytest.mark.parametrize("base_class", (HyperparametersMixin, LightningModule, LightningDataModule))
 def test_save_hyperparameters_under_composition(base_class):
     """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
     collected."""

From 4dd9c2bb141c7bc168484e0602ba80ec5fad0cf0 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 11 Aug 2022 01:48:35 +0200
Subject: [PATCH 5/6] code smell?

---
 src/pytorch_lightning/utilities/parsing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index c4a73ec571108..073423ab60773 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -165,7 +165,7 @@ def collect_init_args(
     frame: types.FrameType,
     path_args: List[Dict[str, Any]],
     inside: bool = False,
-    classes: Tuple[Type, ...] = (object,),
+    classes: Tuple[Type, ...] = (),
 ) -> List[Dict[str, Any]]:
     """Recursively collects the arguments passed to the child constructors in the inheritance tree.
 
@@ -185,7 +185,7 @@ def collect_init_args(
     if not isinstance(frame.f_back, types.FrameType):
         return path_args
 
-    if "__class__" in local_vars and issubclass(local_vars["__class__"], classes):
+    if "__class__" in local_vars and (not classes or issubclass(local_vars["__class__"], classes)):
         local_args = get_init_args(frame)
         # recursive update
         path_args.append(local_args)

From 70ca4ad4e0ee703e5ccc084af24a324483d718d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 11 Aug 2022 11:14:15 -0400
Subject: [PATCH 6/6] Update src/pytorch_lightning/CHANGELOG.md

Co-authored-by: Rohit Gupta
---
 src/pytorch_lightning/CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 277ec5389dc40..04db3d1908bb2 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -70,7 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
 
 
-- Fix saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
+- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
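The net effect of the series: save_hyperparameters() now only collects constructor arguments from stack frames whose class is a HyperparametersMixin subclass, so a plain wrapper class composing a Lightning object no longer leaks its own arguments into the child's hparams. Below is a minimal usage sketch, not part of the patches; the Child/PlainWrapper names and arguments are made up, and it assumes a pytorch_lightning build that includes this fix.

# Illustrative sketch only: mirrors the behavior exercised by
# test_save_hyperparameters_under_composition in the patches above.
import pytorch_lightning as pl


class Child(pl.LightningModule):
    def __init__(self, learning_rate=0.01):
        super().__init__()
        # With the fix, only frames whose __class__ is a HyperparametersMixin
        # subclass (e.g. LightningModule, LightningDataModule) are inspected.
        self.save_hyperparameters()


class PlainWrapper:  # deliberately not a LightningModule/LightningDataModule
    def __init__(self, learning_rate=0.1, other_arg="not a hyperparameter"):
        super().__init__()  # creates a __class__ reference in this frame
        self.model = Child(learning_rate=learning_rate)


wrapper = PlainWrapper()
# Before the fix, the wrapper's constructor arguments could leak into the
# child's hparams because the wrapper's __init__ frame was also collected;
# with the fix, only the child's own arguments are recorded.
assert wrapper.model.hparams == dict(learning_rate=0.1)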