Skip to content

Commit

Permalink
warn user when dropping unpicklable hparams (#2874)
Browse files Browse the repository at this point in the history
* refactored clean_namespace

* Update try except to handle pickling error

* Consolidated clean_namespace. Added is_picklable

* PEP8

* Change warning to use rank_zero_warn. Added Test to ensure proper hparam filtering

* Updated imports

* Corrected Test Case
  • Loading branch information
monney authored Aug 28, 2020
1 parent 85cd558 commit d5254ff
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable

try:
from apex import amp
Expand Down
39 changes: 22 additions & 17 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.

import inspect
import pickle
from argparse import Namespace
from typing import Dict

from pytorch_lightning.utilities import rank_zero_warn


def str_to_bool(val):
"""Convert a string representation of truth to true (1) or false (0).
Expand All @@ -39,26 +42,28 @@ def str_to_bool(val):
raise ValueError(f'invalid truth value {val}')


def is_picklable(obj: object) -> bool:
"""Tests if an object can be pickled"""

try:
pickle.dumps(obj)
return True
except pickle.PicklingError:
return False


def clean_namespace(hparams):
"""Removes all functions from hparams so we can pickle."""
"""Removes all unpicklable entries from hparams"""

hparams_dict = hparams
if isinstance(hparams, Namespace):
del_attrs = []
for k in hparams.__dict__:
if callable(getattr(hparams, k)):
del_attrs.append(k)

for k in del_attrs:
delattr(hparams, k)

elif isinstance(hparams, dict):
del_attrs = []
for k, v in hparams.items():
if callable(v):
del_attrs.append(k)

for k in del_attrs:
del hparams[k]
hparams_dict = hparams.__dict__

del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]

for k in del_attrs:
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning)
del hparams_dict[k]


def get_init_args(frame) -> dict:
Expand Down
21 changes: 19 additions & 2 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities import AttributeDict, is_picklable
from tests.base import EvalModelTemplate, TrialMNIST


Expand Down Expand Up @@ -282,7 +282,7 @@ def test_collect_init_arguments(tmpdir, cls):
assert model.hparams.batch_size == 179

if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.hparams.my_loss, torch.nn.CrossEntropyLoss)
assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

if isinstance(model, DictConfSubClassEvalModel):
assert isinstance(model.hparams.dict_conf, Container)
Expand Down Expand Up @@ -413,6 +413,23 @@ def test_hparams_pickle(tmpdir):
assert ad == pickle.loads(pkl)


class UnpickleableArgsEvalModel(EvalModelTemplate):
""" A model that has an attribute that cannot be pickled. """

def __init__(self, foo='bar', pickle_me=(lambda x: x + 1), **kwargs):
super().__init__(**kwargs)
assert not is_picklable(pickle_me)
self.save_hyperparameters()


def test_hparams_pickle_warning(tmpdir):
model = UnpickleableArgsEvalModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
trainer.fit(model)
assert 'pickle_me' not in model.hparams


def test_hparams_save_yaml(tmpdir):
hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
nasted=dict(any_num=123, anystr='abcd'))
Expand Down

0 comments on commit d5254ff

Please sign in to comment.