Skip to content

Commit ac9c948

Browse files
rohitgr7carmocca
andcommitted
Fix to avoid val progress bar disappear after validate (#11700)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
1 parent 7c51811 commit ac9c948

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
1313
- The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
1414
- Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))
15+
- Fixed an issue to avoid val bar disappear after `trainer.validate()` ([#11700](https://github.com/PyTorchLightning/pytorch-lightning/pull/11700))
1516

1617

1718
## [1.5.9] - 2022-01-18

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ def init_predict_tqdm(self) -> Tqdm:
183183
def init_validation_tqdm(self) -> Tqdm:
184184
"""Override this to customize the tqdm bar for validation."""
185185
# The main progress bar doesn't exist in `trainer.validate()`
186-
has_main_bar = self.main_progress_bar is not None
186+
has_main_bar = self.trainer.state.fn != "validate"
187187
bar = Tqdm(
188188
desc="Validating",
189189
position=(2 * self.process_position + has_main_bar),
190190
disable=self.is_disabled,
191-
leave=False,
191+
leave=not has_main_bar,
192192
dynamic_ncols=True,
193193
file=sys.stdout,
194194
)

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ def test_tqdm_progress_bar_totals(tmpdir):
104104
m = bar.total_val_batches
105105
assert len(trainer.train_dataloader) == n
106106
assert bar.main_progress_bar.total == n + m
107+
assert bar.main_progress_bar.leave
107108

108109
# check val progress bar total
109110
assert sum(len(loader) for loader in trainer.val_dataloaders) == m
110111
assert bar.val_progress_bar.total == m
112+
assert not bar.val_progress_bar.leave
111113

112114
# main progress bar should have reached the end (train batches + val batches)
113115
assert bar.main_progress_bar.n == n + m
@@ -126,13 +128,15 @@ def test_tqdm_progress_bar_totals(tmpdir):
126128
assert bar.val_progress_bar.total == m
127129
assert bar.val_progress_bar.n == m
128130
assert bar.val_batch_idx == m
131+
assert bar.val_progress_bar.leave
129132

130133
trainer.test(model)
131134

132135
# check test progress bar total
133136
k = bar.total_test_batches
134137
assert sum(len(loader) for loader in trainer.test_dataloaders) == k
135138
assert bar.test_progress_bar.total == k
139+
assert bar.test_progress_bar.leave
136140

137141
# test progress bar should have reached the end
138142
assert bar.test_progress_bar.n == k

0 commit comments

Comments
 (0)