
Commit e631a66

Update TQDM progress bar tracking with multiple dataloaders (#11657)

1 parent 28dac0c

6 files changed: +337, -188 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -350,6 +350,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `parallel_devices` property in `ParallelStrategy` to be lazy initialized ([#11572](https://github.com/PyTorchLightning/pytorch-lightning/pull/11572))
 
 
+- Updated `TQDMProgressBar` to run a separate progress bar for each eval dataloader ([#11657](https://github.com/PyTorchLightning/pytorch-lightning/pull/11657))
+
+
 - Sorted `SimpleProfiler(extended=False)` summary based on mean duration for each hook ([#11671](https://github.com/PyTorchLightning/pytorch-lightning/pull/11671))
 
 
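For context, here is a minimal usage sketch (not part of this commit) of the behaviour the changelog entry above describes: a `Trainer.fit` call with two validation dataloaders, where `TQDMProgressBar` is now expected to render a separate eval progress bar per dataloader instead of one combined bar. The toy module, dataset shapes, and hyperparameters are illustrative assumptions.

    # Sketch only: shows two val dataloaders so the per-dataloader eval bars are visible.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import TQDMProgressBar


    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            x, y = batch
            self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    def make_loader(n):
        return DataLoader(TensorDataset(torch.randn(n, 8), torch.randn(n, 1)), batch_size=4)


    if __name__ == "__main__":
        trainer = pl.Trainer(max_epochs=1, callbacks=[TQDMProgressBar(refresh_rate=1)])
        # Two val dataloaders -> a separate eval progress bar per dataloader after this change.
        trainer.fit(ToyModule(), make_loader(64), val_dataloaders=[make_loader(32), make_loader(16)])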

pytorch_lightning/callbacks/progress/base.py

Lines changed: 58 additions & 17 deletions
@@ -49,13 +49,34 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
 
     def __init__(self) -> None:
         self._trainer: Optional["pl.Trainer"] = None
+        self._current_eval_dataloader_idx: Optional[int] = None
 
     @property
     def trainer(self) -> "pl.Trainer":
         if self._trainer is None:
             raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.")
         return self._trainer
 
+    @property
+    def sanity_check_description(self) -> str:
+        return "Sanity Checking"
+
+    @property
+    def train_description(self) -> str:
+        return "Training"
+
+    @property
+    def validation_description(self) -> str:
+        return "Validation"
+
+    @property
+    def test_description(self) -> str:
+        return "Testing"
+
+    @property
+    def predict_description(self) -> str:
+        return "Predicting"
+
     @property
     def train_batch_idx(self) -> int:
         """The number of batches processed during training.
@@ -71,8 +92,12 @@ def val_batch_idx(self) -> int:
         Use this to update your progress bar.
         """
         if self.trainer.state.fn == "fit":
-            return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed
-        return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed
+            loop = self.trainer.fit_loop.epoch_loop.val_loop
+        else:
+            loop = self.trainer.validate_loop
+
+        current_batch_idx = loop.epoch_loop.batch_progress.current.processed
+        return current_batch_idx
 
     @property
     def test_batch_idx(self) -> int:
@@ -100,39 +125,55 @@ def total_train_batches(self) -> Union[int, float]:
         return self.trainer.num_training_batches
 
     @property
-    def total_val_batches(self) -> Union[int, float]:
-        """The total number of validation batches, which may change from epoch to epoch.
+    def total_val_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of validation batches, which may change from epoch to epoch for current dataloader.
 
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
         dataloader is of infinite size.
         """
+        assert self._current_eval_dataloader_idx is not None
         if self.trainer.sanity_checking:
-            return sum(self.trainer.num_sanity_val_batches)
-
-        total_val_batches = 0
-        if self.trainer.enable_validation:
-            is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
-            total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
+            return self.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx]
 
-        return total_val_batches
+        return self.trainer.num_val_batches[self._current_eval_dataloader_idx]
 
     @property
-    def total_test_batches(self) -> Union[int, float]:
-        """The total number of testing batches, which may change from epoch to epoch.
+    def total_test_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of testing batches, which may change from epoch to epoch for current dataloader.
 
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
        of infinite size.
         """
-        return sum(self.trainer.num_test_batches)
+        assert self._current_eval_dataloader_idx is not None
+        return self.trainer.num_test_batches[self._current_eval_dataloader_idx]
 
     @property
-    def total_predict_batches(self) -> Union[int, float]:
-        """The total number of prediction batches, which may change from epoch to epoch.
+    def total_predict_batches_current_dataloader(self) -> Union[int, float]:
+        """The total number of prediction batches, which may change from epoch to epoch for current dataloader.
 
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
         is of infinite size.
         """
-        return sum(self.trainer.num_predict_batches)
+        assert self._current_eval_dataloader_idx is not None
+        return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]
+
+    @property
+    def total_val_batches(self) -> Union[int, float]:
+        """The total number of validation batches, which may change from epoch to epoch for all val dataloaders.
+
+        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
+        is of infinite size.
+        """
+        assert self._trainer is not None
+        return sum(self.trainer.num_val_batches) if self._trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0
+
+    def has_dataloader_changed(self, dataloader_idx: int) -> bool:
+        old_dataloader_idx = self._current_eval_dataloader_idx
+        self._current_eval_dataloader_idx = dataloader_idx
+        return old_dataloader_idx != dataloader_idx
+
+    def reset_dataloader_idx_tracker(self) -> None:
+        self._current_eval_dataloader_idx = None
 
     def disable(self) -> None:
         """You should provide a way to disable the progress bar."""

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 53 additions & 28 deletions
@@ -262,22 +262,6 @@ def is_enabled(self) -> bool:
     def is_disabled(self) -> bool:
         return not self.is_enabled
 
-    @property
-    def sanity_check_description(self) -> str:
-        return "Validation Sanity Check"
-
-    @property
-    def validation_description(self) -> str:
-        return "Validation"
-
-    @property
-    def test_description(self) -> str:
-        return "Testing"
-
-    @property
-    def predict_description(self) -> str:
-        return "Predicting"
-
     def _update_for_light_colab_theme(self) -> None:
         if _detect_light_colab_theme():
             attributes = ["description", "batch_progress", "metrics"]
@@ -354,13 +338,28 @@ def on_train_epoch_start(self, trainer, pl_module):
         )
         self.refresh()
 
-    def on_validation_epoch_start(self, trainer, pl_module):
+    def on_validation_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if not self.has_dataloader_changed(dataloader_idx):
+            return
+
         if trainer.sanity_checking:
-            self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
+            if self.val_sanity_progress_bar_id is not None:
+                self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
+
+            self.val_sanity_progress_bar_id = self._add_task(
+                self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
+            )
         else:
+            if self.val_progress_bar_id is not None:
+                self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+
+            # TODO: remove old tasks when new onces are created
             self.val_progress_bar_id = self._add_task(
-                self.total_val_batches, self.validation_description, visible=False
+                self.total_val_batches_current_dataloader, self.validation_description, visible=False
             )
+
         self.refresh()
 
     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
@@ -387,13 +386,36 @@ def on_validation_epoch_end(self, trainer, pl_module):
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.state.fn == "fit":
             self._update_metrics(trainer, pl_module)
+        self.reset_dataloader_idx_tracker()
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.reset_dataloader_idx_tracker()
 
-    def on_test_epoch_start(self, trainer, pl_module):
-        self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
+    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self.reset_dataloader_idx_tracker()
+
+    def on_test_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if not self.has_dataloader_changed(dataloader_idx):
+            return
+
+        if self.test_progress_bar_id is not None:
+            self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
+        self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description)
         self.refresh()
 
-    def on_predict_epoch_start(self, trainer, pl_module):
-        self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
+    def on_predict_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
+        if not self.has_dataloader_changed(dataloader_idx):
+            return
+
+        if self.predict_progress_bar_id is not None:
+            self.progress.update(self.predict_progress_bar_id, advance=0, visible=False)
+        self.predict_progress_bar_id = self._add_task(
+            self.total_predict_batches_current_dataloader, self.predict_description
+        )
         self.refresh()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
@@ -406,20 +428,23 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         if trainer.sanity_checking:
-            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
         elif self.val_progress_bar_id is not None:
             # check to see if we should update the main training progress bar
             if self.main_progress_bar_id is not None:
-                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
-            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+                # TODO: Use total val_processed here just like TQDM in a follow-up
+                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
+            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
         self.refresh()
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
+        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
         self.refresh()
 
     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
+        self._update(
+            self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
+        )
         self.refresh()
 
     def _get_train_description(self, current_epoch: int) -> str:
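As a standalone illustration (using the rich library directly, outside Lightning) of the task-swap pattern the RichProgressBar changes above apply per eval dataloader: when the dataloader index changes, the previous task is hidden via `update(..., visible=False)` and a new task sized to that dataloader's batch count is added. The batch counts and descriptions below are made up.

    # Sketch of the per-dataloader task swap with rich's Progress API.
    import time
    from rich.progress import Progress

    num_batches_per_dataloader = [5, 3]  # stand-in for e.g. trainer.num_test_batches

    with Progress() as progress:
        task_id = None
        for dataloader_idx, total in enumerate(num_batches_per_dataloader):
            if task_id is not None:
                # Hide the previous dataloader's bar, mirroring progress.update(..., visible=False) above.
                progress.update(task_id, advance=0, visible=False)
            task_id = progress.add_task(f"Testing DataLoader {dataloader_idx}", total=total)
            for _ in range(total):
                time.sleep(0.05)  # stand-in for processing one batch
                progress.advance(task_id)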
