Disable lr_scheduler.step() in manual optimization #6825

Merged
merged 24 commits into from
Apr 20, 2021
Changes from 12 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Disabled `lr_scheduler.step()` in manual optimization ([#6825](https://github.com/PyTorchLightning/pytorch-lightning/pull/6825))


### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
4 changes: 4 additions & 0 deletions docs/source/common/optimizers.rst
@@ -44,6 +44,8 @@ to manually manage the optimization process. To do so, do the following:

.. warning:: Before 1.2, ``optimizer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the user's expertise.

.. warning:: Before 1.3, ``lr_scheduler.step`` was called automatically in both manual and automatic optimization. From 1.3, ``lr_scheduler.step`` is disabled in manual optimization so that you can call it at arbitrary intervals. Use ``self.lr_schedulers()`` in LightningModule to access your learning rate schedulers defined in ``LightningModule.configure_optimizers()``.

.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do the following.

.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.
@@ -75,6 +77,7 @@ Here is the same example as above using a ``closure``.

     def training_step(self, batch, batch_idx):
         opt = self.optimizers()
+        sch = self.lr_schedulers()

         def closure():
             # Only zero_grad on the first batch to accumulate gradients
@@ -87,6 +90,7 @@ Here is the same example as above using a ``closure``.
             return loss

         opt.step(closure=closure)
+        sch.step()

.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good practice to call ``zero_grad`` before ``manual_backward``.

@@ -32,7 +32,7 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
             interval: either 'epoch' or 'step'.
             monitor_metrics: dict of possible values to monitor
         """
-        if not self.trainer.lr_schedulers:
+        if not self.trainer.lr_schedulers or not self.trainer.train_loop.automatic_optimization:
Contributor

Let's update the doc with this change + add a warning if we detect the scheduler was on step or epoch, but never triggered.

Contributor Author

@akihironitta akihironitta Apr 7, 2021

@tchaton I updated the docs and added a warning when scheduler keys (e.g. "interval", "frequency", "monitor"), which are invalid in manual optimization, are provided. Would you mind having a look at the changes?
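
For readers following the thread, the key check described in this reply can be sketched in plain Python. This is a minimal illustration, not the code merged in the PR: `IGNORED_KEYS_IN_MANUAL_OPTIMIZATION` and `warn_on_ignored_scheduler_keys` are hypothetical names, and the standard `warnings` module stands in for Lightning's `rank_zero_warn`. The set of keys is copied from the `invalid_keys` set in the `optimizers.py` diff of this PR.

```python
import warnings

# Keys the PR treats as ignored in manual optimization (copied from the diff).
IGNORED_KEYS_IN_MANUAL_OPTIMIZATION = {
    "interval", "frequency", "reduce_on_plateau", "monitor", "strict",
}


def warn_on_ignored_scheduler_keys(scheduler_config: dict) -> list:
    """Return the ignored keys found in the config, warning if there are any."""
    keys_to_warn = [
        key for key in scheduler_config
        if key in IGNORED_KEYS_IN_MANUAL_OPTIMIZATION
    ]
    if keys_to_warn:
        # Hypothetical stand-in for rank_zero_warn: same message, same category.
        warnings.warn(
            f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
            " You need to call `lr_scheduler.step()` manually in manual optimization.",
            RuntimeWarning,
        )
    return keys_to_warn


if __name__ == "__main__":
    # object() is a placeholder for a real scheduler instance.
    config = {"scheduler": object(), "interval": "epoch", "frequency": 1}
    print(warn_on_ignored_scheduler_keys(config))
```

A config that only contains the `"scheduler"` key produces no warning, which is the shape a manual-optimization user would be expected to return.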

Contributor

Looks good. Could we add to the docs how to use the scheduler for interval=epoch?

    def training_step(self, batch, batch_idx):
        # do optimization

        # is_last_batch doesn't exist for the trainer yet; it could be added.
        if self.trainer.is_last_batch:
            self.lr_schedulers().step()

Contributor Author

@akihironitta akihironitta Apr 8, 2021

@tchaton I added self.trainer.is_last_batch and updated the docs. Could you have a look again?
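
The pattern discussed here, stepping the scheduler only on the last batch of each epoch, can be sketched without Lightning. In this illustration `ToyStepLR` and `run_manual_epochs` are hypothetical stand-ins (not `torch.optim.lr_scheduler.StepLR` or any Lightning API), and the last-batch flag is computed locally rather than read from `self.trainer.is_last_batch`.

```python
class ToyStepLR:
    """Hypothetical step-decay scheduler: lr = base_lr * gamma ** (steps // step_size)."""

    def __init__(self, base_lr: float, step_size: int, gamma: float):
        self.base_lr = base_lr
        self.step_size = step_size
        self.gamma = gamma
        self.num_steps = 0

    def step(self) -> None:
        self.num_steps += 1

    @property
    def lr(self) -> float:
        return self.base_lr * self.gamma ** (self.num_steps // self.step_size)


def run_manual_epochs(scheduler: ToyStepLR, num_epochs: int, batches_per_epoch: int) -> list:
    """Emulate a manual-optimization loop that steps the scheduler once per epoch."""
    lr_after_each_epoch = []
    for _ in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            # zero_grad / backward / optimizer.step would happen here.
            is_last_batch = batch_idx == batches_per_epoch - 1  # local stand-in flag
            if is_last_batch:
                scheduler.step()
        lr_after_each_epoch.append(scheduler.lr)
    return lr_after_each_epoch


if __name__ == "__main__":
    toy = ToyStepLR(base_lr=0.1, step_size=1, gamma=0.5)
    print(run_manual_epochs(toy, num_epochs=3, batches_per_epoch=4))
```

Because the user owns the call site, the same loop could step every N batches instead by swapping the condition, which is the flexibility the PR aims for.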

             return

         for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
22 changes: 19 additions & 3 deletions pytorch_lightning/trainer/optimizers.py
@@ -80,7 +80,8 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
                 ' * A list of the previously described dict format, with an optional "frequency" key (int)'
             )

-        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor)
+        is_manual_optimization = not model.automatic_optimization
+        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
         _validate_scheduler_optimizer(optimizers, lr_schedulers)

         return optimizers, lr_schedulers, optimizer_frequencies
@@ -98,8 +99,13 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
             for opt_idx, opt in enumerate(self.optimizers)
         }

-    def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
-        # Convert each scheduler into dict structure with relevant information
+    def configure_schedulers(
+        self,
+        schedulers: list,
+        monitor: Optional[str],
+        is_manual_optimization: bool,
+    ):
+        """Convert each scheduler into dict structure with relevant information"""
         lr_schedulers = []
         default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
@@ -117,6 +123,16 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                         f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                         f' but is "{scheduler["interval"]}"'
                     )
+                if is_manual_optimization:
+                    invalid_keys = {'interval', 'frequency', 'reduce_on_plateau', 'monitor', 'strict'}
+                    keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
+
+                    if keys_to_warn:
+                        rank_zero_warn(
+                            f'The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored.'
+                            ' You need to call `lr_scheduler.step()` manually in manual optimization.',
+                            RuntimeWarning,
+                        )

                 scheduler['reduce_on_plateau'] = isinstance(
                     scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
38 changes: 38 additions & 0 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -1147,3 +1147,41 @@ def dis_closure():
 @RunIf(min_gpus=2, special=True)
 def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
     train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)
+
+
+def test_lr_scheduler_step_not_called(tmpdir):
+    """
+    Test `lr_scheduler.step()` is not called in manual optimization.
+    """
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt = self.optimizers()
+
+            output = self(batch)
+            loss = self.loss(batch, output)
+
+            opt.zero_grad()
+            self.manual_backward(loss)
+            opt.step()
+
+    model = TestModel()
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        fast_dev_run=2,
+    )
+
+    with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
+        trainer.fit(model)
+
+    # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+    # `.step()` is called once during its instantiation.
+    # Thus, the call count should be 1.
+    assert lr_step.call_count == 1
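
The reason the assertion expects a call count of 1 rather than 0 can be reproduced with the standard library alone. The sketch below assumes a hypothetical `ToyScheduler` whose constructor calls `step()` once, mirroring the behaviour described in the test's comment; it is not the PyTorch class, and `count_step_calls` is an illustrative helper.

```python
from unittest.mock import patch


class ToyScheduler:
    """Hypothetical scheduler whose constructor performs one initial step()."""

    def __init__(self):
        self.last_epoch = -1
        self.step()  # initial step during instantiation

    def step(self):
        self.last_epoch += 1


def count_step_calls(num_explicit_steps: int) -> int:
    """Patch ToyScheduler.step and count how often it is called."""
    with patch.object(ToyScheduler, "step") as mocked_step:
        scheduler = ToyScheduler()  # the constructor's call is recorded by the mock
        for _ in range(num_explicit_steps):
            scheduler.step()
    return mocked_step.call_count


if __name__ == "__main__":
    print(count_step_calls(0))
    print(count_step_calls(3))
```

With no explicit calls the mock still records the constructor's call, so a count of 1 is the baseline that means "nothing else stepped the scheduler".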
22 changes: 22 additions & 0 deletions tests/trainer/optimization/test_optimizers.py
@@ -476,3 +476,25 @@ def configure_optimizers(self):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.raises(MisconfigurationException, match="attatched with an optimizer that wasn't returned"):
         trainer.fit(model)
+
+
+def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
+    """
+    Test warning when invalid scheduler keys are provided in manual optimization.
+    """
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def configure_optimizers(self):
+            opt = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
+            return [opt], [{"scheduler": sch, "interval": "epoch"}]
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
+        trainer.fit(model)