
Commit 6fd7082

update logic and chlog
1 parent d28e797 commit 6fd7082

File tree

9 files changed, +26 -24 lines changed


docs/source-pytorch/common/trainer.rst

Lines changed: 2 additions & 2 deletions
@@ -1480,7 +1480,7 @@ Can specify as float or int.

 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
 - pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
-  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across ``trainer.fit``.
+  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training.

 .. testcode::

@@ -1493,7 +1493,7 @@ Can specify as float or int.
     # check validation set every 1000 training batches in the current epoch
     trainer = Trainer(val_check_interval=1000)

-    # check validation set every 1000 training batches across complete training
+    # check validation set every 1000 training batches across complete epochs or during iteration-based training
     # use this when using iterableDataset and your dataset has no length
     # (ie: production cases with streaming data)
     trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)
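A minimal sketch of how the updated trigger behaves with these settings (not part of the diff; the 2500-batch run length is illustrative only):

# Illustration only: with check_val_every_n_epoch=None, validation triggers whenever the
# total number of training batches processed so far (across epochs) is a multiple of
# val_check_interval, here 1000. Epoch boundaries do not reset the counter.
val_check_interval = 1000
total_batches_in_run = 2500  # assumed run length, for illustration

validation_points = [
    b for b in range(1, total_batches_in_run + 1) if b % val_check_interval == 0
]
print(validation_points)  # [1000, 2000]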

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


--
+- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


 ### Deprecated
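The distinction matters under gradient accumulation: the optimizer-step counter grows more slowly than the count of processed batches. A rough, self-contained sketch with values chosen to mirror the updated test further below (not code from the diff):

# Illustration: with accumulate_grad_batches=2 the optimizer steps once per two batches,
# so a step counter reaches only half the value of the processed-batch counter.
accumulate_grad_batches = 2
val_check_interval = 3

validated_at = []
for batch_idx in range(12):  # 12 training batches processed in total
    total_batches_processed = batch_idx + 1
    optimizer_steps = total_batches_processed // accumulate_grad_batches
    if total_batches_processed % val_check_interval == 0:  # new rule: count processed batches
        validated_at.append(total_batches_processed)

print(validated_at)     # [3, 6, 9, 12] -- independent of accumulate_grad_batches
print(optimizer_steps)  # 6, i.e. half as many optimizer steps as processed batches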

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 4 additions & 4 deletions
@@ -173,7 +173,7 @@ def total_val_batches(self) -> Union[int, float]:
         return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

     @property
-    def total_main_progress_bar_count_current_epoch(self) -> Union[int, float]:
+    def total_batches_current_epoch(self) -> Union[int, float]:
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
         assert self._trainer is not None
@@ -182,9 +182,9 @@ def total_main_progress_bar_count_current_epoch(self) -> Union[int, float]:
         # val can be checked multiple times per epoch
         val_check_batch = self.trainer.val_check_batch
         if self.trainer.check_val_every_n_epoch is None:
-            batches_that_stepped = self.trainer.fit_loop.epoch_loop._batches_that_stepped
-            val_checks_per_epoch = ((batches_that_stepped + total_train_batches) // val_check_batch) - (
-                batches_that_stepped // val_check_batch
+            train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
+            val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
+                train_batches_processed // val_check_batch
             )
         else:
             val_checks_per_epoch = total_train_batches // val_check_batch
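The updated expression counts how many multiples of val_check_batch fall inside the window of batches that the upcoming epoch adds to the global counter. A self-contained sketch with illustrative numbers (not taken from the diff):

# Illustrative values: 7 batches already processed, 5 more in the upcoming epoch,
# validation every 3 processed batches.
train_batches_processed = 7
total_train_batches = 5
val_check_batch = 3

val_checks_per_epoch = (
    (train_batches_processed + total_train_batches) // val_check_batch
    - train_batches_processed // val_check_batch
)
print(val_checks_per_epoch)  # 2 -> validation at global batches 9 and 12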

src/pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def on_sanity_check_end(self, trainer, pl_module):
         self.refresh()

     def on_train_epoch_start(self, trainer, pl_module):
-        total_batches = self.total_main_progress_bar_count_current_epoch
+        total_batches = self.total_batches_current_epoch
         train_description = self._get_train_description(trainer.current_epoch)

         if self.main_progress_bar_id is not None and self._leave:

src/pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def on_train_start(self, *_: Any) -> None:
         self.main_progress_bar = self.init_train_tqdm()

     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
-        total_batches = self.total_main_progress_bar_count_current_epoch
+        total_batches = self.total_batches_current_epoch
         self.main_progress_bar.reset(convert_inf(total_batches))
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 6 additions & 7 deletions
@@ -163,7 +163,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[ov
         Raises:
             StopIteration: When the epoch is canceled by the user returning -1
         """
-        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
+        if self.restarting and self._should_check_val_fx():
             # skip training and run validation in `on_advance_end`
             return
         # we are going to train first so the val loop does not need to restart
@@ -235,7 +235,7 @@ def on_advance_end(self) -> None:
         # -----------------------------------------
         # VALIDATE IF NEEDED
         # -----------------------------------------
-        should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
+        should_check_val = self._should_check_val_fx()
         if should_check_val:
             self.trainer.validating = True
             self._run_validation()
@@ -496,13 +496,14 @@ def _should_check_val_epoch(self) -> bool:
             or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         )

-    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
+    def _should_check_val_fx(self) -> bool:
         """Decide if we should run validation."""
         if not self._should_check_val_epoch():
             return False

         # val_check_batch is inf for iterable datasets with no length defined
         is_infinite_dataset = self.trainer.val_check_batch == float("inf")
+        is_last_batch = self.batch_progress.is_last_batch
         if is_last_batch and is_infinite_dataset:
             return True

@@ -512,13 +513,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
-            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
+            is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
             # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches
             # else condition it based on the batch_idx of the current epoch
-            current_iteration = (
-                self._batches_that_stepped if self.trainer.check_val_every_n_epoch is None else batch_idx
-            )
+            current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
             is_val_check_batch = (current_iteration + 1) % self.trainer.val_check_batch == 0

         return is_val_check_batch
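Stripped of the restart and last-batch special cases, the updated decision comes down to which counter feeds the modulo test. A paraphrased sketch, not the actual method, with parameter names mirroring the loop attributes above:

# Paraphrase of the post-change rule: use the run-wide batch counter when
# check_val_every_n_epoch is None, otherwise the in-epoch batch counter.
def should_check_val(total_batch_idx, batch_idx, val_check_batch, check_val_every_n_epoch):
    current_iteration = total_batch_idx if check_val_every_n_epoch is None else batch_idx
    return (current_iteration + 1) % val_check_batch == 0

# e.g. val_check_interval=3 with check_val_every_n_epoch=None: validation fires after
# the 3rd, 6th, 9th, ... training batch processed, even across epoch boundaries.
assert should_check_val(total_batch_idx=5, batch_idx=1, val_check_batch=3, check_val_every_n_epoch=None)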

src/pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -395,7 +395,7 @@ def __init__(
             after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
             batches. An ``int`` value can only be higher than the number of training batches when
             ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-            across ``trainer.fit``.
+            across epochs or during iteration-based training.
             Default: ``1.0``.

         enable_model_summary: Whether to enable model summarization by default.

tests/tests_pytorch/callbacks/progress/test_base_progress.py

Lines changed: 3 additions & 3 deletions
@@ -29,6 +29,6 @@ def test_main_progress_bar_with_val_check_interval_int():
     trainer.reset_val_dataloader()
     expected = [15, 25, 25, 15]

-    for expected_count in expected:
-        assert trainer.progress_bar_callback.total_main_progress_bar_count_current_epoch == expected_count
-        trainer.fit_loop.epoch_loop._batches_that_stepped += train_batches
+    for count in expected:
+        assert trainer.progress_bar_callback.total_batches_current_epoch == count
+        trainer.fit_loop.epoch_loop.batch_progress.total.ready += train_batches

tests/tests_pytorch/trainer/flags/test_val_check_interval.py

Lines changed: 7 additions & 4 deletions
@@ -63,18 +63,20 @@ def test_val_check_interval_info_message(caplog, value):


 @pytest.mark.parametrize("use_infinite_dataset", [True, False])
-def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset):
+@pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
+def test_validation_check_interval_exceed_data_length_correct(tmpdir, use_infinite_dataset, accumulate_grad_batches):
     data_samples_train = 4
     max_epochs = 3
     max_steps = data_samples_train * max_epochs
+    max_opt_steps = max_steps // accumulate_grad_batches

     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.validation_called_at_step = set()

         def validation_step(self, *args):
-            self.validation_called_at_step.add(self.global_step)
+            self.validation_called_at_step.add(self.trainer.fit_loop.total_batch_idx + 1)
             return super().validation_step(*args)

         def train_dataloader(self):
@@ -89,16 +91,17 @@ def train_dataloader(self):
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_val_batches=1,
-        max_steps=max_steps,
+        max_steps=max_opt_steps,
         val_check_interval=3,
         check_val_every_n_epoch=None,
         num_sanity_val_steps=0,
+        accumulate_grad_batches=accumulate_grad_batches,
     )

     trainer.fit(model)

     assert trainer.current_epoch == 1 if use_infinite_dataset else max_epochs
-    assert trainer.global_step == max_steps
+    assert trainer.global_step == max_opt_steps
     assert sorted(list(model.validation_called_at_step)) == [3, 6, 9, 12]
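The expected validation points in the final assertion follow directly from the test's constants; spelled out for reference (an illustrative recomputation, not part of the diff):

# 4 samples per epoch * 3 epochs = 12 training batches in total; with val_check_interval=3
# and check_val_every_n_epoch=None, validation runs every 3 processed batches.
data_samples_train, max_epochs, val_check_interval = 4, 3, 3
total_batches = data_samples_train * max_epochs
expected_points = [b for b in range(1, total_batches + 1) if b % val_check_interval == 0]
assert expected_points == [3, 6, 9, 12]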
