Bugfix: Lr finder and hparams compatibility #2821

Merged 10 commits on Aug 6, 2020

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -96,6 +96,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed shell injection vulnerability in subprocess call ([#2786](https://github.com/PyTorchLightning/pytorch-lightning/pull/2786))

- Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))

## [0.8.5] - 2020-07-09

### Added
41 changes: 12 additions & 29 deletions pytorch_lightning/trainer/lr_finder.py
@@ -24,7 +24,7 @@
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr


class TrainerLRFinderMixin(ABC):
@@ -57,24 +57,26 @@ def _run_lr_finder_internally(self, model: LightningModule):
""" Call lr finder internally during Trainer.fit() """
lr_finder = self.lr_find(model)
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
if isinstance(self.auto_lr_find, str):
# Try to find requested field, may be nested
if _nested_hasattr(model, self.auto_lr_find):
_nested_setattr(model, self.auto_lr_find, lr)
if lightning_hasattr(model, self.auto_lr_find):
lightning_setattr(model, self.auto_lr_find, lr)
else:
raise MisconfigurationException(
f'`auto_lr_find` was set to {self.auto_lr_find}, however'
' could not find this as a field in `model.hparams`.')
' could not find this as a field in `model` or `model.hparams`.')
else:
if hasattr(model, 'lr'):
model.lr = lr
elif hasattr(model, 'learning_rate'):
model.learning_rate = lr
if lightning_hasattr(model, 'lr'):
lightning_setattr(model, 'lr', lr)
elif lightning_hasattr(model, 'learning_rate'):
lightning_setattr(model, 'learning_rate', lr)
else:
raise MisconfigurationException(
'When auto_lr_find is set to True, expects that hparams'
' either has field `lr` or `learning_rate` that can overridden')
                    'When auto_lr_find is set to True, expects that `model` or'
                    ' `model.hparams` either has field `lr` or `learning_rate`'
                    ' that can be overridden')
        log.info(f'Learning rate set to {lr}')

def lr_find(
@@ -492,22 +494,3 @@ def get_lr(self):
    @property
    def lr(self):
        return self._lr


def _nested_hasattr(obj, path):
    parts = path.split(".")
    for part in parts:
        if hasattr(obj, part):
            obj = getattr(obj, part)
        else:
            return False
    else:
        return True


def _nested_setattr(obj, path, val):
    parts = path.split(".")
    for part in parts[:-1]:
        if hasattr(obj, part):
            obj = getattr(obj, part)
    setattr(obj, parts[-1], val)
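
For context, here is a minimal usage sketch of the behaviour this change enables (not part of the diff): a model that keeps its learning rate only in `hparams` can now be tuned with `auto_lr_find`, because the trainer writes the suggested value back through `lightning_setattr`. The model, layer sizes and dataloader are illustrative placeholders, and it assumes the 0.8.x-era API in which the LR finder runs inside `Trainer.fit()`, as the docstring above states.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class HParamsOnlyModel(pl.LightningModule):
    """Toy model whose learning rate lives only in `self.hparams`."""

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()  # exposes self.hparams.learning_rate
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def train_dataloader(self):
        x, y = torch.randn(256, 32), torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)

    def configure_optimizers(self):
        # the lr is read from hparams, mirroring configure_optimizers__lr_from_hparams below
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


model = HParamsOnlyModel()
trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
trainer.fit(model)  # the suggested lr is written back into model.hparams.learning_rate
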
53 changes: 53 additions & 0 deletions pytorch_lightning/utilities/parsing.py
@@ -140,3 +140,56 @@ def __repr__(self):
        rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
        out = '\n'.join(rows)
        return out


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
attr = True
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
else:
attr = False

return attr


def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
attr = model.hparams[attribute]
else:
attr = getattr(model.hparams, attribute)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
return attr


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)
# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
else:
setattr(model.hparams, attribute, value)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
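
A short usage sketch of the three helpers above (not part of the diff), showing the two storage styles they reconcile: an attribute set directly on the model and an entry in an old-style `hparams` dict. The toy classes are illustrative.

from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr


class PlainModel:
    learning_rate = 0.01  # stored directly on the model


class HParamsDictModel:
    hparams = {'learning_rate': 0.02}  # stored in the hparams dict


for model in (PlainModel(), HParamsDictModel()):
    assert lightning_hasattr(model, 'learning_rate')
    old = lightning_getattr(model, 'learning_rate')
    lightning_setattr(model, 'learning_rate', old * 10)  # writes to the correct place
    assert lightning_getattr(model, 'learning_rate') == old * 10
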
4 changes: 4 additions & 0 deletions tests/base/model_optimizers.py
@@ -70,3 +70,7 @@ def configure_optimizers__param_groups(self):
        optimizer = optim.Adam(param_groups)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
        return [optimizer], [lr_scheduler]

    def configure_optimizers__lr_from_hparams(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
4 changes: 2 additions & 2 deletions tests/core/test_datamodules.py
@@ -4,7 +4,7 @@
import torch
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed
@@ -291,7 +291,7 @@ def test_full_loop_ddp_spawn(tmpdir):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

    reset_seed()
    seed_everything(1234)

    dm = TrialMNISTDataModule(tmpdir)

30 changes: 24 additions & 6 deletions tests/trainer/test_lr_finder.py
@@ -73,11 +73,15 @@ def test_trainer_reset_correctly(tmpdir):
            f'Attribute {key} was not reset correctly after learning rate finder'


def test_trainer_arg_bool(tmpdir):
@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_bool(tmpdir, use_hparams):
    """ Test that setting trainer arg to bool works """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    before_lr = hparams.get('learning_rate')
    if use_hparams:
        del model.learning_rate
        model.configure_optimizers = model.configure_optimizers__lr_from_hparams

    # logger file to get meta
    trainer = Trainer(
@@ -87,17 +91,27 @@ def test_trainer_arg_bool(tmpdir):
    )

    trainer.fit(model)
    after_lr = model.learning_rate
    if use_hparams:
        after_lr = model.hparams.learning_rate
    else:
        after_lr = model.learning_rate

    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'


def test_trainer_arg_str(tmpdir):
@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_str(tmpdir, use_hparams):
    """ Test that setting trainer arg to string works """
    model = EvalModelTemplate()
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    model.my_fancy_lr = 1.0  # update with non-standard field

    model.hparams['my_fancy_lr'] = 1.0
    before_lr = model.my_fancy_lr
    if use_hparams:
        del model.my_fancy_lr
        model.configure_optimizers = model.configure_optimizers__lr_from_hparams

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
@@ -106,7 +120,11 @@ def test_trainer_arg_str(tmpdir):
    )

    trainer.fit(model)
    after_lr = model.my_fancy_lr
    if use_hparams:
        after_lr = model.hparams.my_fancy_lr
    else:
        after_lr = model.my_fancy_lr

    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'

61 changes: 61 additions & 0 deletions tests/utilities/parsing.py
@@ -0,0 +1,61 @@
import pytest

from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr


def _get_test_cases():
    class TestHparamsNamespace:
        learning_rate = 1

    TestHparamsDict = {'learning_rate': 2}

    class TestModel1:  # test for namespace
        learning_rate = 0
    model1 = TestModel1()

    class TestModel2:  # test for hparams namespace
        hparams = TestHparamsNamespace()

    model2 = TestModel2()

    class TestModel3:  # test for hparams dict
        hparams = TestHparamsDict

    model3 = TestModel3()

    class TestModel4:  # fail case
        batch_size = 1

    model4 = TestModel4()

    return model1, model2, model3, model4


def test_lightning_hasattr(tmpdir):
    """ Test that the lightning_hasattr works in all cases"""
    model1, model2, model3, model4 = _get_test_cases()
    assert lightning_hasattr(model1, 'learning_rate'), \
        'lightning_hasattr failed to find namespace variable'
    assert lightning_hasattr(model2, 'learning_rate'), \
        'lightning_hasattr failed to find hparams namespace variable'
    assert lightning_hasattr(model3, 'learning_rate'), \
        'lightning_hasattr failed to find hparams dict variable'
    assert not lightning_hasattr(model4, 'learning_rate'), \
        'lightning_hasattr found variable when it should not'


def test_lightning_getattr(tmpdir):
    """ Test that the lightning_getattr works in all cases"""
    models = _get_test_cases()
    for i, m in enumerate(models[:3]):
        value = lightning_getattr(m, 'learning_rate')
        assert value == i, 'attribute not correctly extracted'


def test_lightning_setattr(tmpdir):
    """ Test that the lightning_setattr works in all cases"""
    models = _get_test_cases()
    for m in models[:3]:
        lightning_setattr(m, 'learning_rate', 10)
        assert lightning_getattr(m, 'learning_rate') == 10, \
            'attribute not correctly set'
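

The fail case (model4) above is only exercised through lightning_hasattr; a possible extra check, sketched here and not part of the diff, would assert that the getattr/setattr helpers raise the ValueError they document when the attribute is missing from both the model and `hparams`.

def test_lightning_get_set_attr_missing(tmpdir):
    """ Hypothetical extra test: missing attributes should raise ValueError """
    *_, model4 = _get_test_cases()
    with pytest.raises(ValueError):
        lightning_getattr(model4, 'learning_rate')
    with pytest.raises(ValueError):
        lightning_setattr(model4, 'learning_rate', 10)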