Enable purely iteration-based training #5726

Merged
Changes from 13 commits
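In short, this PR lets a `Trainer` be bounded by step counts alone. A minimal sketch of the intended usage, assuming the `BoringModel` test helper used elsewhere in this PR (any `LightningModule` would do):

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test helper from this repo; any LightningModule works

# Purely iteration-based training: only max_steps is given.
# With this PR, max_epochs stays None (instead of defaulting to 1000) whenever
# max_steps is set, so training stops after exactly 100 optimizer steps.
trainer = Trainer(max_steps=100)
trainer.fit(BoringModel())
assert trainer.global_step == 100
```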
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -109,6 +109,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Extended support for purely iteration-based training ([#5726](https://github.com/PyTorchLightning/pytorch-lightning/pull/5726))


- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


@@ -158,7 +158,7 @@ def restore_training_state(self, checkpoint):
self.trainer.current_epoch = checkpoint['epoch']

# crash if max_epochs is lower then the current epoch from the checkpoint
-if self.trainer.current_epoch > self.trainer.max_epochs:
+if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
m = f"""
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
but the Trainer(max_epochs={self.trainer.max_epochs})
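For context, a hedged sketch of the situation the relaxed restore-time guard handles (the checkpoint path is purely illustrative, and the imports are the same as in the usage sketch above):

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

# Illustrative only: "some_run.ckpt" stands in for an existing checkpoint file.
# With max_epochs=None, the comparison current_epoch > max_epochs would raise a
# TypeError, so the restore-time check above now runs only when max_epochs is set;
# the resumed run here is bounded purely by max_steps.
trainer = Trainer(max_steps=100, resume_from_checkpoint="some_run.ckpt")
trainer.fit(BoringModel())
```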
16 changes: 10 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -14,6 +14,7 @@
"""Trainer to automate the training."""

import warnings
from itertools import count
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union

@@ -101,8 +102,8 @@ def __init__(
check_val_every_n_epoch: int = 1,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
-max_epochs: int = 1000,
-min_epochs: int = 1,
+max_epochs: Optional[int] = None,
+min_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
min_steps: Optional[int] = None,
limit_train_batches: Union[int, float] = 1.0,
@@ -231,9 +232,11 @@ def __init__(

precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

-max_epochs: Stop training once this number of epochs is reached.
+max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
+    If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

-min_epochs: Force training for at least these many epochs
+min_epochs: Force training for at least these many epochs. Disabled by default (None).
+    If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

max_steps: Stop training after this number of steps. Disabled by default (None).

@@ -586,7 +589,8 @@ def train(self):
if self.train_loop.should_skip_training():
return
# run all epochs
-for epoch in range(self.current_epoch, self.max_epochs):
+epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch)
+for epoch in epochs:

# hook
self.train_loop.on_train_epoch_start(epoch)
@@ -599,7 +603,7 @@ def train(self):
return

# early stopping
-met_min_epochs = epoch >= self.min_epochs - 1
+met_min_epochs = epoch >= self.min_epochs - 1 if self.min_epochs else True
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

if self.should_stop:
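A standalone sketch of the loop construction used above, showing how `itertools.count` stands in for an unbounded `range` when `max_epochs` is None (the numbers are arbitrary and the helper is illustrative, independent of the Trainer):

```python
from itertools import count

def epoch_iterator(current_epoch, max_epochs):
    # Mirrors the logic in train(): a bounded range when max_epochs is given,
    # otherwise an open-ended counter starting at the current epoch.
    return range(current_epoch, max_epochs) if max_epochs else count(current_epoch)

# Bounded: epochs 3 and 4 remain when resuming at epoch 3 with max_epochs=5.
assert list(epoch_iterator(3, 5)) == [3, 4]

# Unbounded: with max_epochs=None the iterator never ends on its own;
# training must then be stopped by max_steps (or early stopping).
it = epoch_iterator(3, None)
assert [next(it) for _ in range(4)] == [3, 4, 5, 6]
```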
10 changes: 7 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -71,8 +71,10 @@ def on_trainer_init(
self.trainer.train_dataloader = None
self.automatic_optimization = automatic_optimization

-self.trainer.max_epochs = max_epochs
-self.trainer.min_epochs = min_epochs
+# If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
+self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
+# If neither max_epochs or max_steps is set, then use existing default of min_epochs = 1
+self.trainer.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.trainer.max_steps = max_steps
self.trainer.min_steps = min_steps

@@ -93,7 +95,9 @@ def num_optimizers(self):
return num_optimizers

def should_skip_training(self):
-return self.trainer.current_epoch >= self.trainer.max_epochs or self.trainer.num_training_batches == 0
+return (
+    self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs
+) or self.trainer.num_training_batches == 0

def on_train_start(self):
# clear cache before training
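To make the default resolution above concrete, a small hedged sketch of how the four flags combine (the `resolve_limits` helper is illustrative, not part of the diff):

```python
def resolve_limits(max_epochs=None, min_epochs=None, max_steps=None, min_steps=None):
    # Illustrative re-statement of on_trainer_init's defaulting rule:
    # the epoch-based defaults only kick in when no step-based limit was given.
    max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
    min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
    return max_epochs, min_epochs

assert resolve_limits() == (1000, 1)                 # no flags: old epoch-based defaults
assert resolve_limits(max_steps=100) == (None, 1)    # step limit given: no epoch cap
assert resolve_limits(min_steps=10, max_steps=100) == (None, None)  # purely iteration-based
```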
43 changes: 43 additions & 0 deletions tests/trainer/flags/test_min_max_epochs.py
@@ -0,0 +1,43 @@
import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


# @pytest.mark.parametrize("min_epochs", [None, 2])
# @pytest.mark.parametrize("max_epochs", [None, 3])
# @pytest.mark.parametrize("min_steps", [None, 20])
# @pytest.mark.parametrize("max_steps", [None, 100])
Contributor:
Are all of these 16 configurations valid?
Did you try what happens if we choose max_epochs = 1 and max_steps = 5 but the epoch only has 3 batches?
The limit that leads to fewer training steps should terminate training, right? This could be a separate test.

Contributor Author:
If both are specified, I think it should exit as soon as the first condition is met. So for the example you provided, if the epoch has only 3 batches, I agree we should return after the epoch completes and before moving on to the remaining 2 steps. The same goes for min_steps.

I think users can already exercise this functionality now, so I'll check whether there are already test cases for it; if not, I'll add them here.

@pytest.mark.parametrize(
["min_epochs", "max_epochs", "min_steps", "max_steps"],
[
pytest.param(None, 5, None, None),
pytest.param(None, None, None, 100),
pytest.param(None, 5, None, 100),
pytest.param(None, None, 10, 100),
pytest.param(1, 5, None, None),
pytest.param(1, None, None, 100),
pytest.param(None, 5, 10, None),
],
)
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
"""
Tests that max_steps can be used without max_epochs
"""
model = BoringModel()

trainer = Trainer(
default_root_dir=tmpdir,
min_epochs=min_epochs,
max_epochs=max_epochs,
min_steps=min_steps,
max_steps=max_steps,
weights_summary=None,
)

result = trainer.fit(model)
assert result == 1, "Training did not complete"

# check training stopped at max_epochs or max_steps
if trainer.max_steps and not trainer.max_epochs:
assert trainer.global_step == trainer.max_steps
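Following the review discussion above, a hedged sketch of the extra case that could be tested separately: whichever of `max_epochs` / `max_steps` is reached first should end training. The test name and batch counts are assumptions, not part of this PR, and the imports are the same as in the file above.

```python
def test_max_steps_vs_max_epochs_interaction(tmpdir):
    """Sketch: max_epochs=1 with only 3 batches per epoch should win over max_steps=5."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=3,  # assumed: caps each epoch at 3 batches
        max_epochs=1,
        max_steps=5,
        weights_summary=None,
    )
    trainer.fit(model)
    # Training should stop at the end of the single epoch (3 optimizer steps),
    # before the 5-step limit is ever reached.
    assert trainer.global_step == 3
```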
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -286,7 +286,7 @@ def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
model = EvalModelTemplate(**hparams)
model.hparams = hparams
# now we have model.batch_size and model.hparams.batch_size
-trainer = Trainer(default_root_dir=tmpdir, max_steps=1, auto_scale_batch_size=True)
+trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True)
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):
trainer.tune(model)