Add logging messages to notify when FitLoop stopping conditions are met #9749

Merged
Changes from 10 commits
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
 
 
+- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749))
+
+
 - Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))
 
29 changes: 19 additions & 10 deletions src/pytorch_lightning/loops/fit_loop.py
Expand Up @@ -33,7 +33,7 @@
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, rank_zero_info, rank_zero_debug
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

log = logging.getLogger(__name__)
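
Context for the import change above: unlike the module-level `log.info` used previously, the `rank_zero_*` helpers emit only from the process with global rank 0, so a multi-GPU run prints each stopping notice once rather than once per worker. A minimal sketch of how the two levels behave (the message strings are illustrative; the logger name is the one the new test targets):

    import logging

    from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info

    # INFO-level notices are visible with Lightning's default logging setup.
    rank_zero_info("printed once per job, not once per process")

    # DEBUG-level notices stay hidden until the rank-zero logger is lowered.
    logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.DEBUG)
    rank_zero_debug("now visible on rank 0")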
@@ -150,31 +150,40 @@ def _results(self) -> _ResultCollection:
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.num_training_batches == 0:
+            rank_zero_info("`Trainer.fit` stopped: No training batches.")
+            return True
+
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
+        if stop_steps:
+            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
+            return True
 
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
         if stop_epochs:
             # in case they are not equal, override so `trainer.current_epoch` has the expected value
             self.epoch_progress.current.completed = self.epoch_progress.current.processed
+            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
+            return True
 
-        should_stop = False
         if self.trainer.should_stop:
             # early stopping
             met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
             met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
             if met_min_epochs and met_min_steps:
-                should_stop = True
+                self.trainer.should_stop = True
+                rank_zero_debug("`Trainer.fit` stopped: `should_stop` was set.")
+                return True
             else:
-                log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                rank_zero_info(
+                    f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
+                    f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
 
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        self.trainer.should_stop = False
+        return False
 
     @property
     def skip(self) -> bool:
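
For readers following the control flow: `_is_max_limit_reached` is defined elsewhere in Lightning and is not part of this diff. With the early returns added above, each stopping condition logs its own reason the moment it is met, instead of being folded into the single boolean expression the old `return` combined. A rough, hypothetical sketch of the helper's contract, assuming the usual convention that `-1` (or `None`) disables a limit:

    from typing import Optional

    # Illustrative reimplementation only; the real helper lives elsewhere in
    # pytorch_lightning and may differ in detail.
    def _is_max_limit_reached(current: int, maximum: Optional[int] = -1) -> bool:
        # No limit configured -> this can never be the stopping reason.
        if maximum is None or maximum == -1:
            return False
        # Otherwise stop once the progress counter meets or passes the limit.
        return current >= maximum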
47 changes: 47 additions & 0 deletions tests/tests_pytorch/loops/test_training_loop.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from unittest.mock import Mock
+
 import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.loops import FitLoop
 
 
 def test_outputs_format(tmpdir):
@@ -136,6 +140,49 @@ def validation_step(self, *args):
     assert model.validation_called_at == (0, 5)
 
 
+def test_fit_loop_done_log_messages(caplog):
+    fit_loop = FitLoop()
+    trainer = Mock(spec=Trainer)
+    fit_loop.trainer = trainer
+
+    trainer.should_stop = False
+    trainer.num_training_batches = 5
+    assert not fit_loop.done
+    assert not caplog.messages
+
+    trainer.num_training_batches = 0
+    assert fit_loop.done
+    assert "No training batches" in caplog.text
+    caplog.clear()
+    trainer.num_training_batches = 5
+
+    epoch_loop = Mock()
+    epoch_loop.global_step = 10
+    fit_loop.connect(epoch_loop=epoch_loop)
+    fit_loop.max_steps = 10
+    assert fit_loop.done
+    assert "max_steps=10` reached" in caplog.text
+    caplog.clear()
+    fit_loop.max_steps = 20
+
+    fit_loop.epoch_progress.current.processed = 3
+    fit_loop.max_epochs = 3
+    trainer.should_stop = True
+    assert fit_loop.done
+    assert "max_epochs=3` reached" in caplog.text
+    caplog.clear()
+    fit_loop.max_epochs = 5
+
+    fit_loop.epoch_loop.min_steps = 0
+    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
+        assert fit_loop.done
+    assert "should_stop` was set" in caplog.text
+
+    fit_loop.epoch_loop.min_steps = 100
+    assert not fit_loop.done
+    assert "was signaled to stop but" in caplog.text
+
+
 def test_warning_valid_train_step_end(tmpdir):
     class ValidTrainStepEndModel(BoringModel):
         def training_step(self, batch, batch_idx):
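
To see the new notifications outside the test suite, a minimal end-to-end run is enough. This sketch assumes Lightning's default logging setup; the expected line is quoted from the diff above, and the DEBUG-level `should_stop` notice only appears after raising the same logger the test targets:

    import logging

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    # Optional: surface DEBUG-level notices such as "`should_stop` was set."
    logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.DEBUG)

    trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=0)
    trainer.fit(BoringModel())
    # Expected once training finishes:
    # `Trainer.fit` stopped: `max_epochs=1` reached.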