Allow disabling automatic stopping after max_steps or max_epochs #8877

Merged
Changes from 24 of 37 commits.

Commits:
379dda2
Update docstring + logic to disable automatic stopping
EricWiener Aug 12, 2021
d506881
Add a test to check passing negative max_epochs
EricWiener Aug 12, 2021
8d9e707
Updated logic for disabling automatic stopping
EricWiener Aug 27, 2021
3a7cc8a
Updated test cases for max_epochs/max_steps + max_time
EricWiener Aug 27, 2021
a0b8c61
Change brackets to parentheses
EricWiener Aug 27, 2021
a239358
Corrected max_epoch error checking restore_loops
EricWiener Aug 27, 2021
d201464
Validating max_epochs and max_steps
EricWiener Aug 27, 2021
91422f5
Added parameterized tests for max_epochs + max_steps
EricWiener Aug 27, 2021
c6562b4
Shortened timer to 1 sec from 10 sec
EricWiener Aug 27, 2021
f7e8176
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 27, 2021
78a1eb8
Fix type error comparing to None
EricWiener Aug 28, 2021
5558917
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
6c5439c
Added FitLoop._is_max_limit_enabled
EricWiener Aug 28, 2021
cd7732a
Removed mentioning max_epochs in max_steps docstring
EricWiener Aug 28, 2021
7209976
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
794bf31
Remove type signature on `max_value`
EricWiener Aug 29, 2021
945b731
Remove type signature on return value
EricWiener Aug 29, 2021
1b70637
Now checking that max vals are int (vs. not float)
EricWiener Aug 29, 2021
6e53053
Condensed test_timer test
EricWiener Aug 29, 2021
fcada92
Moved details desc of max_epochs/steps to trainer.rst
EricWiener Aug 29, 2021
769c7fc
Shortened max_* desc in trainer.rst
EricWiener Aug 29, 2021
4e6bed4
Update pytorch_lightning/trainer/connectors/checkpoint_connector.py
awaelchli Sep 1, 2021
fe11371
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 1, 2021
e8a3174
Change brackets to parentheses
EricWiener Sep 2, 2021
1e13018
Update docs/source/common/trainer.rst
EricWiener Sep 2, 2021
db53b8c
Update pytorch_lightning/trainer/trainer.py
EricWiener Sep 2, 2021
b805ff5
No longer checking if max_epochs/steps is an int
EricWiener Sep 2, 2021
4c5f5d8
Fixed test_trainer_max_steps_and_epochs_fit_loop_done
EricWiener Sep 3, 2021
05ed3a3
Fix test_timer.py::test_trainer_flag
EricWiener Sep 3, 2021
d21b009
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2021
4594284
Fixed test_trainer_max_steps_and_epochs_validation
EricWiener Sep 3, 2021
318fec8
Decrease global step in tests/trainer/test_trainer.py
EricWiener Sep 3, 2021
0c322ea
Change EvalModelTemplate to BoringModel
EricWiener Sep 3, 2021
145c27b
Moved max_* validation into constructors
EricWiener Sep 3, 2021
3b3f29f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2021
38b4436
Fix pre-commit
carmocca Sep 4, 2021
1a17f87
Keep TODO at the top
carmocca Sep 4, 2021
docs/source/common/trainer.rst (6 additions, 0 deletions)
@@ -907,6 +907,9 @@ Stop training once this number of epochs is reached
# default used by the Trainer
trainer = Trainer(max_epochs=1000)

If both ``max_epochs`` and ``max_steps`` aren't specified, ``max_epochs`` will default to ``1000``.
To disable this default, set ``max_epochs = -1``.

min_epochs
^^^^^^^^^^

@@ -947,6 +950,9 @@ Training will stop if max_steps or max_epochs have reached (earliest).
# Stop after 100 steps
trainer = Trainer(max_steps=100)

If ``max_steps`` is not specified, ``max_epochs`` will be used instead (and ``max_epochs`` defaults to
``1000`` if ``max_epochs`` is not specified). To disable this default, set ``max_steps = -1``.

min_steps
^^^^^^^^^

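Taken together, the documented behavior can be exercised as follows. A minimal usage sketch (not part of the PR; only the standard Trainer import is assumed):

from pytorch_lightning import Trainer

# Default: neither limit is given, so max_epochs falls back to 1000.
trainer = Trainer()

# Disable the epoch limit and bound the run by wall time instead.
trainer = Trainer(max_epochs=-1, max_time=dict(seconds=30))

# Disable both defaults: training runs until interrupted or stopped by a callback.
trainer = Trainer(max_epochs=-1, max_steps=-1)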
pytorch_lightning/loops/fit_loop.py (15 additions, 2 deletions)
@@ -124,6 +124,19 @@ def _results(self) -> ResultCollection:
return self.epoch_loop.val_loop._results
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

@staticmethod
def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
"""Checks whether the max_value is enabled. This can
be used for checking whether max_epochs or max_steps is enabled.

Args:
max_value: the value to check

Returns:
whether the limit for this value should be enabled
"""
return max_value not in (None, -1)

@property
def done(self) -> bool:
"""Evaluates when to leave the loop.
@@ -132,8 +145,8 @@ def done(self) -> bool:
or if the maximum number of steps or epochs is reached.
"""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
- stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
- stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs
+ stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps
+ stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs

should_stop = False
if self.trainer.should_stop:
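For reference, the helper treats None and -1 as "no limit", while 0 and positive integers are real limits. A standalone restatement (a sketch, not additional PR code):

from typing import Optional

def is_max_limit_enabled(max_value: Optional[int]) -> bool:
    # Mirrors FitLoop._is_max_limit_enabled: None (unset) and -1 (explicit opt-out) disable the limit.
    return max_value not in (None, -1)

assert not is_max_limit_enabled(None)  # unset: limit disabled
assert not is_max_limit_enabled(-1)    # opt-out: limit disabled
assert is_max_limit_enabled(0)         # zero is a real, immediately-reached limit
assert is_max_limit_enabled(1000)      # ordinary limit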
pytorch_lightning/trainer/configuration_validator.py (11 additions, 0 deletions)
@@ -90,6 +90,17 @@ def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None
"(rather, they are called on every optimization step)."
)

# Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
if trainer.max_epochs and (trainer.max_epochs < -1 or not isinstance(trainer.max_epochs, int)):
raise MisconfigurationException(
f"`max_epochs` must be a positive integer or -1. You passed in {trainer.max_epochs}."
)

if trainer.max_steps and (trainer.max_steps < -1 or not isinstance(trainer.max_steps, int)):
raise MisconfigurationException(
f"`max_steps` must be a positive integer or -1. You passed in {trainer.max_steps}."
)

def __verify_eval_loop_configuration(self, model: "pl.LightningModule", stage: str) -> None:
loader_name = f"{stage}_dataloader"
step_name = "validation_step" if stage == "val" else "test_step"
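Note that the leading truthiness check (trainer.max_epochs and ...) is what lets 0 and None through without raising. A sketch of the intended accept/reject behavior (assuming, as the parameterized test below does, that the check fires when the Trainer is constructed):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Accepted: unset, explicit opt-out, zero, and positive integers.
for ok in (None, -1, 0, 1, 1000):
    Trainer(max_epochs=ok)

# Rejected: values below -1 (and, at this point in the PR, non-integers).
with pytest.raises(MisconfigurationException, match="must be a positive integer or -1"):
    Trainer(max_epochs=-100)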
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,6 +21,7 @@
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -190,7 +191,10 @@ def restore_loops(self) -> None:
self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]

# crash if max_epochs is lower than the current epoch from the checkpoint
- if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
+ if (
+     FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+     and self.trainer.current_epoch > self.trainer.max_epochs
+ ):
raise MisconfigurationException(
f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
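The new guard matters here because a bare is-not-None check would make max_epochs=-1 always trip the crash-on-restore exception, since current_epoch > -1 holds for every checkpoint. A sketch (not PR code) contrasting the two conditions:

def old_check(current_epoch, max_epochs):
    # Pre-PR condition: -1 is treated as a real limit.
    return max_epochs is not None and current_epoch > max_epochs

def new_check(current_epoch, max_epochs):
    # Post-PR condition: -1 (like None) disables the limit.
    return max_epochs not in (None, -1) and current_epoch > max_epochs

# Restoring a checkpoint at epoch 5 with the limit disabled via -1:
assert old_check(5, -1)      # would wrongly raise on restore
assert not new_check(5, -1)  # correctly allowed to continue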
pytorch_lightning/trainer/trainer.py (7 additions, 3 deletions)
@@ -261,12 +261,15 @@ def __init__(
Can be used on CPU, GPU or TPUs.

max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
-     If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
+     If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
+     To disable this default, set ``max_epochs = -1``.

min_epochs: Force training for at least these many epochs. Disabled by default (None).
-     If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.
+     If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

- max_steps: Stop training after this number of steps. Disabled by default (None).
+ max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
+     and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
+     ``max_steps`` to ``-1``.

min_steps: Force training for at least these number of steps. Disabled by default (None).

@@ -374,6 +377,7 @@ def __init__(
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)

+ # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
fit_loop = FitLoop(
min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
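The FitLoop construction above applies the 1/1000 defaults only when no related limit, including max_time, was supplied. Restated as a small pure function (the function name is invented here for illustration):

from typing import Optional

def resolve_limits(min_epochs: Optional[int], min_steps: Optional[int],
                   max_epochs: Optional[int], max_steps: Optional[int], max_time):
    # Mirrors the expressions passed to FitLoop in Trainer.__init__.
    resolved_min = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
    resolved_max = 1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs
    return resolved_min, resolved_max

assert resolve_limits(None, None, None, None, None) == (1, 1000)                  # bare Trainer()
assert resolve_limits(None, None, None, None, dict(seconds=30)) == (None, None)   # max_time suppresses both defaults
assert resolve_limits(None, None, -1, None, None) == (1, -1)                      # -1 counts as "specified", so no 1000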
tests/callbacks/test_timer.py (6 additions, 2 deletions)
@@ -42,8 +42,12 @@ def on_fit_start(self):
trainer.fit(TestModel())
assert "callbacks list already contains a Timer" in caplog.text

- seconds = 1
- trainer = Trainer(max_time=dict(seconds=seconds))
+ # Make sure max_time still honored even if max_epochs == -1
+ trainer = Trainer(max_time=dict(seconds=1), max_epochs=-1)
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1
assert trainer.max_epochs is None
assert trainer.max_steps is None

tests/trainer/test_trainer.py (50 additions, 0 deletions)
@@ -491,6 +491,56 @@ def test_trainer_max_steps_and_epochs(tmpdir):
assert trainer.global_step == num_train_samples * trainer.max_epochs
assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"

# if max_steps is positive and max_epochs is negative, use max_steps
trainer_kwargs["max_epochs"] = -1
trainer_kwargs["max_steps"] = 3 * 2 * num_train_samples
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)

assert trainer.state.finished, f"Training failed with {trainer.state}"
assert trainer.global_step == 3 * 2 * num_train_samples


@pytest.mark.parametrize(
"max_epochs,max_steps,incorrect_variable,incorrect_value",
[
(-100, None, "max_epochs", -100),
(1.5, None, "max_epochs", 1.5),
(1, -2, "max_steps", -2),
(1, 0.5, "max_steps", 0.5),
],
)
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
"""Don't allow max_epochs or max_steps to be less than -1 or a float"""
with pytest.raises(
MisconfigurationException,
match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}",
):
trainer = Trainer(max_epochs=max_epochs, max_steps=max_steps)


@pytest.mark.parametrize(
"max_epochs,max_steps,is_done",
[
(None, None, False),
(-1, None, False),
(None, -1, False),
(5, -1, False),
(-1, 10, False),
(None, 0, True),
(0, None, True),
(-1, 0, True),
(0, -1, True),
],
)
def test_trainer_max_steps_and_epochs_fit_loop_done(max_epochs, max_steps, is_done):
trainer = Trainer(max_epochs=max_epochs, max_steps=max_steps)

assert trainer.max_epochs == max_epochs
assert trainer.max_steps == max_steps
assert trainer.max_time is None
assert trainer.fit_loop.done is is_done


def test_trainer_min_steps_and_epochs(tmpdir):
"""Verify model trains according to specified min steps"""
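The done table above reduces to the two guarded comparisons in FitLoop.done. A standalone restatement (a sketch, not PR code; it ignores trainer.should_stop and assumes a fresh trainer at epoch 0 and step 0):

def fit_loop_done(max_epochs, max_steps, current_epoch=0, global_step=0):
    def enabled(value):
        return value not in (None, -1)

    stop_steps = enabled(max_steps) and global_step >= max_steps
    stop_epochs = enabled(max_epochs) and current_epoch >= max_epochs
    return stop_steps or stop_epochs

assert not fit_loop_done(None, None)  # defaults are resolved elsewhere; the loop itself is not done
assert not fit_loop_done(-1, None)    # limits disabled: never done on its own
assert fit_loop_done(0, None)         # zero epochs allowed: done immediately
assert fit_loop_done(-1, 0)           # zero steps allowed: done immediately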