Skip to content

Commit ae98031

Browse files
Add logging messages to notify when FitLoop stopping conditions are met (#9749)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
1 parent 4f53e71 commit ae98031

File tree

4 files changed

+70
-11
lines changed

4 files changed

+70
-11
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9393
- Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
9494

9595

96+
- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749))
97+
98+
9699
- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))
97100

98101

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
InterBatchParallelDataFetcher,
3434
)
3535
from pytorch_lightning.utilities.model_helpers import is_overridden
36-
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
36+
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
3737
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
3838

3939
log = logging.getLogger(__name__)
@@ -150,31 +150,40 @@ def _results(self) -> _ResultCollection:
150150
@property
151151
def done(self) -> bool:
152152
"""Evaluates when to leave the loop."""
153+
if self.trainer.num_training_batches == 0:
154+
rank_zero_info("`Trainer.fit` stopped: No training batches.")
155+
return True
156+
153157
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
154158
stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
159+
if stop_steps:
160+
rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
161+
return True
162+
155163
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
156164
# we use it here because the checkpoint data won't have `completed` increased yet
157165
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
158166
if stop_epochs:
159167
# in case they are not equal, override so `trainer.current_epoch` has the expected value
160168
self.epoch_progress.current.completed = self.epoch_progress.current.processed
169+
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
170+
return True
161171

162-
should_stop = False
163172
if self.trainer.should_stop:
164173
# early stopping
165174
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
166175
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
167176
if met_min_epochs and met_min_steps:
168-
should_stop = True
177+
self.trainer.should_stop = True
178+
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
179+
return True
169180
else:
170-
log.info(
171-
"Trainer was signaled to stop but required minimum epochs"
172-
f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
173-
" not been met. Training will continue..."
181+
rank_zero_info(
182+
f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
183+
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
174184
)
175-
self.trainer.should_stop = should_stop
176-
177-
return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
185+
self.trainer.should_stop = False
186+
return False
178187

179188
@property
180189
def skip(self) -> bool:

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
15+
from unittest.mock import Mock
16+
1417
import pytest
1518
import torch
1619

1720
from pytorch_lightning import seed_everything, Trainer
1821
from pytorch_lightning.demos.boring_classes import BoringModel
22+
from pytorch_lightning.loops import FitLoop
1923

2024

2125
def test_outputs_format(tmpdir):
@@ -136,6 +140,49 @@ def validation_step(self, *args):
136140
assert model.validation_called_at == (0, 5)
137141

138142

143+
def test_fit_loop_done_log_messages(caplog):
144+
fit_loop = FitLoop()
145+
trainer = Mock(spec=Trainer)
146+
fit_loop.trainer = trainer
147+
148+
trainer.should_stop = False
149+
trainer.num_training_batches = 5
150+
assert not fit_loop.done
151+
assert not caplog.messages
152+
153+
trainer.num_training_batches = 0
154+
assert fit_loop.done
155+
assert "No training batches" in caplog.text
156+
caplog.clear()
157+
trainer.num_training_batches = 5
158+
159+
epoch_loop = Mock()
160+
epoch_loop.global_step = 10
161+
fit_loop.connect(epoch_loop=epoch_loop)
162+
fit_loop.max_steps = 10
163+
assert fit_loop.done
164+
assert "max_steps=10` reached" in caplog.text
165+
caplog.clear()
166+
fit_loop.max_steps = 20
167+
168+
fit_loop.epoch_progress.current.processed = 3
169+
fit_loop.max_epochs = 3
170+
trainer.should_stop = True
171+
assert fit_loop.done
172+
assert "max_epochs=3` reached" in caplog.text
173+
caplog.clear()
174+
fit_loop.max_epochs = 5
175+
176+
fit_loop.epoch_loop.min_steps = 0
177+
with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
178+
assert fit_loop.done
179+
assert "should_stop` was set" in caplog.text
180+
181+
fit_loop.epoch_loop.min_steps = 100
182+
assert not fit_loop.done
183+
assert "was signaled to stop but" in caplog.text
184+
185+
139186
def test_warning_valid_train_step_end(tmpdir):
140187
class ValidTrainStepEndModel(BoringModel):
141188
def training_step(self, batch, batch_idx):

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def training_step(self, batch, batch_idx):
616616
with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
617617
trainer.fit(model)
618618

619-
message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
619+
message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue"
620620
num_messages = sum(1 for record in caplog.records if message in record.message)
621621
assert num_messages == min_epochs - 2
622622
assert model.training_step_invoked == min_epochs * 2

0 commit comments

Comments
 (0)