diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index a6ea1d8467a0d..d0447dc7d1c66 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -6,7 +6,7 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info -from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict +from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable try: from apex import amp diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index ba0c768265166..1422a8ac62f03 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -13,9 +13,12 @@ # limitations under the License. import inspect +import pickle from argparse import Namespace from typing import Dict +from pytorch_lightning.utilities import rank_zero_warn + def str_to_bool(val): """Convert a string representation of truth to true (1) or false (0). @@ -39,26 +42,28 @@ def str_to_bool(val): raise ValueError(f'invalid truth value {val}') +def is_picklable(obj: object) -> bool: + """Tests if an object can be pickled""" + + try: + pickle.dumps(obj) + return True + except pickle.PicklingError: + return False + + def clean_namespace(hparams): - """Removes all functions from hparams so we can pickle.""" + """Removes all unpicklable entries from hparams""" + hparams_dict = hparams if isinstance(hparams, Namespace): - del_attrs = [] - for k in hparams.__dict__: - if callable(getattr(hparams, k)): - del_attrs.append(k) - - for k in del_attrs: - delattr(hparams, k) - - elif isinstance(hparams, dict): - del_attrs = [] - for k, v in hparams.items(): - if callable(v): - del_attrs.append(k) - - for k in del_attrs: - del hparams[k] + hparams_dict = hparams.__dict__ + + del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)] + + for k in del_attrs: + rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning) + del hparams_dict[k] def get_init_args(frame) -> dict: diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index a389aad5cdd67..807d5dcc869fe 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -11,7 +11,7 @@ from pytorch_lightning import Trainer, LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml -from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities import AttributeDict, is_picklable from tests.base import EvalModelTemplate, TrialMNIST @@ -282,7 +282,7 @@ def test_collect_init_arguments(tmpdir, cls): assert model.hparams.batch_size == 179 if isinstance(model, AggSubClassEvalModel): - assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss) + assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss) if isinstance(model, DictConfSubClassEvalModel): assert isinstance(model.hparams.dict_conf, Container) @@ -413,6 +413,23 @@ def test_hparams_pickle(tmpdir): assert ad == pickle.loads(pkl) +class UnpickleableArgsEvalModel(EvalModelTemplate): + """ A model that has an attribute that cannot be pickled. """ + + def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs): + super().__init__(**kwargs) + assert not is_picklable(pickle_me) + self.save_hyperparameters() + + +def test_hparams_pickle_warning(tmpdir): + model = UnpickleableArgsEvalModel() + trainer = Trainer(default_root_dir=tmpdir, max_steps=1) + with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"): + trainer.fit(model) + assert 'pickle_me' not in model.hparams + + def test_hparams_save_yaml(tmpdir): hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here', nasted=dict(any_num=123, anystr='abcd'))