Disable lr_scheduler.step() in manual optimization (#6825)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
3 people authored Apr 20, 2021
1 parent 14e6b46 commit 0302b8b
Showing 7 changed files with 88 additions and 5 deletions.
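What this means for users: when `automatic_optimization` is disabled, Lightning no longer steps learning-rate schedulers for you, so `lr_scheduler.step()` must be called explicitly from `training_step`. Below is a minimal sketch under that assumption — the model, layer, and scheduler choice are illustrative and not taken from this commit:

import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of Lightning's automatic optimization
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()

        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # Lightning no longer calls this for you in manual optimization
        sch.step()
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]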
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -96,6 +96,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))


- Added `is_last_batch` attribute to `Trainer` ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825))


- Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567))


@@ -129,6 +132,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Disabled `lr_scheduler.step()` in manual optimization ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825))


- Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/))


2 changes: 1 addition & 1 deletion docs/source/common/optimizers.rst
@@ -35,7 +35,7 @@ To manually optimize, do the following:
* ``optimizer.step()`` to update your model parameters

Here is a minimal example of manual optimization.

.. testcode:: python

from pytorch_lightning import LightningModule
@@ -32,7 +32,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
"""
- if not self.trainer.lr_schedulers:
+ if not self.trainer.lr_schedulers or not self.trainer.train_loop.automatic_optimization:
return

for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
22 changes: 19 additions & 3 deletions pytorch_lightning/trainer/optimizers.py
@@ -80,7 +80,8 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

- lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+ is_manual_optimization = not self.train_loop.automatic_optimization
+ lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

return optimizers, lr_schedulers, optimizer_frequencies
@@ -98,8 +99,13 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
for opt_idx, opt in enumerate(self.optimizers)
}

- def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
-     # Convert each scheduler into dict structure with relevant information
+ def configure_schedulers(
+     self,
+     schedulers: list,
+     monitor: Optional[str],
+     is_manual_optimization: bool,
+ ) -> List[Dict[str, Any]]:
+     """Convert each scheduler into dict structure with relevant information"""
lr_schedulers = []
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
@@ -117,6 +123,16 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
if is_manual_optimization:
invalid_keys = {'interval', 'frequency', 'reduce_on_plateau', 'monitor', 'strict'}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

if keys_to_warn:
rank_zero_warn(
f'The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored.'
' You need to call `lr_scheduler.step()` manually in manual optimization.',
RuntimeWarning,
)

scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
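In manual optimization the scheduler-dict keys `interval`, `frequency`, `reduce_on_plateau`, `monitor`, and `strict` are ignored, which is exactly what the new warning reports. A warning-free sketch of `configure_optimizers` for a module like the one sketched above (it assumes the same imports and `self.layer`) simply returns the scheduler without those keys:

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # no 'interval'/'frequency'/'monitor' keys: nothing for Lightning to act on,
        # since stepping happens manually inside training_step
        return [optimizer], [scheduler]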
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -472,6 +472,7 @@ def run_training_epoch(self):
for batch_idx, (batch, is_last_batch) in train_dataloader:

self.trainer.batch_idx = batch_idx
self.trainer.is_last_batch = is_last_batch

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
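The `trainer.is_last_batch` attribute set here gives manual optimization a convenient hook for epoch-interval scheduling. A sketch of a `training_step` (again for a module like the one sketched above) that emulates `'interval': 'epoch'` by stepping only on the final batch of each epoch:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()

        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # step the scheduler once per epoch instead of once per batch
        if self.trainer.is_last_batch:
            sch.step()
        return loss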
38 changes: 38 additions & 0 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -1170,3 +1170,41 @@ def configure_optimizers(self):
)

trainer.fit(model)


def test_lr_scheduler_step_not_called(tmpdir):
"""
Test `lr_scheduler.step()` is not called in manual optimization.
"""
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()

output = self(batch)
loss = self.loss(batch, output)

opt.zero_grad()
self.manual_backward(loss)
opt.step()

model = TestModel()
model.training_step_end = None
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
fast_dev_run=2,
)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
trainer.fit(model)

# If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
# `.step()` is called once during its instantiation.
# Thus, the call count should be 1, not 0.
assert lr_step.call_count == 1
22 changes: 22 additions & 0 deletions tests/trainer/optimization/test_optimizers.py
@@ -476,3 +476,25 @@ def configure_optimizers(self):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.raises(MisconfigurationException, match="attatched with an optimizer that wasn't returned"):
trainer.fit(model)


def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
"""
Test warning when invalid scheduler keys are provided in manual optimization.
"""

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.automatic_optimization = False

def configure_optimizers(self):
opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
return [opt], [{"scheduler": sch, "interval": "epoch"}]

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
trainer.fit(model)
