
Avoid warning about logging interval for fast dev run (#18550)
(cherry picked from commit 670b490)
awaelchli authored and lantiga committed Sep 14, 2023
1 parent 72c097e commit 54085d1
Showing 3 changed files with 9 additions and 1 deletion.
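
For context, a minimal sketch of the scenario this commit fixes (the `BoringModel` demo class stands in for any `LightningModule` and is illustrative, not part of the diff): `fast_dev_run=True` caps training at a single batch, which is always below the default `log_every_n_steps=50`, so the old check warned on every fast dev run even though the logging interval is irrelevant there.

```python
# Illustrative reproduction, assuming lightning >= 2.0 and the bundled
# BoringModel demo class; neither appears in the diff itself.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()

# fast_dev_run=True runs exactly one training batch, so max_batches (1) is
# always smaller than the default log_every_n_steps (50). Before this commit,
# fitting here emitted the spurious "number of training batches" warning.
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)
```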
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed visual glitch with the TQDM progress bar leaving the validation bar incomplete before switching back to the training display ([#18503](https://github.com/Lightning-AI/lightning/pull/18503))


+- Fixed false positive warning about logging interval when running with `Trainer(fast_dev_run=True)` ([#18550](https://github.com/Lightning-AI/lightning/pull/18550))


## [2.0.7] - 2023-08-14

### Added
2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/fit_loop.py
@@ -277,7 +277,7 @@ def setup_data(self) -> None:
trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
trainer.val_check_batch = max(1, trainer.val_check_batch)

-if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
+if trainer.loggers and self.max_batches < trainer.log_every_n_steps and not trainer.fast_dev_run:
rank_zero_warn(
f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
5 changes: 5 additions & 0 deletions tests/tests_pytorch/trainer/test_dataloaders.py
@@ -17,6 +17,7 @@
import numpy
import pytest
import torch
+from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, IterableDataset
@@ -693,6 +694,10 @@ def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
)
trainer.fit(model)

+    with no_warning_call(UserWarning, match="The number of training batches"):
+        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, log_every_n_steps=2)
+        trainer.fit(model)


def test_warning_with_iterable_dataset_and_len(tmpdir):
"""Tests that a warning message is shown when an IterableDataset defines `__len__`."""
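
For readers unfamiliar with the helper used in the test above: `no_warning_call` from `lightning_utilities.test.warning` asserts that the wrapped block does not emit a matching warning. A rough, illustrative re-implementation (not the helper's actual source) could look like:

```python
# Illustrative sketch of a no_warning_call-style context manager; the real
# helper lives in lightning_utilities.test.warning and may differ in detail.
import re
import warnings
from contextlib import contextmanager

@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Record every warning raised inside the block.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    # Fail if any recorded warning matches the category and message pattern.
    for w in record:
        if issubclass(w.category, expected_warning) and (
            match is None or re.search(match, str(w.message))
        ):
            raise AssertionError(f"Warning should not have been raised: {w.message}")
```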
