Add ability for TQDMProgressBar to retain prior epoch training bars (Lightning-AI#19578)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
2 people authored and ammyk9 committed Aug 6, 2024
1 parent e42cbf3 commit a8d3239
Showing 4 changed files with 35 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/source-pytorch/common/progress_bar.rst
@@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress
    trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
By default, the training progress bar is reset (overwritten) at each new epoch.
If you wish for a new progress bar to be displayed at the end of every epoch, set
:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.

.. code-block:: python

    trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])

If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
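
For example, a minimal sketch that overrides the ``init_validation_tqdm`` hook (one of several overridable methods; the description text here is purely illustrative):

.. code-block:: python

    from lightning.pytorch.callbacks import TQDMProgressBar


    class LitProgressBar(TQDMProgressBar):
        def init_validation_tqdm(self):
            # Reuse the default validation bar, but give it a custom description.
            bar = super().init_validation_tqdm()
            bar.set_description("running validation...")
            return bar


    trainer = Trainer(callbacks=[LitProgressBar()])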

2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))

-
- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))

### Changed

9 changes: 8 additions & 1 deletion src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -99,12 +99,14 @@ class TQDMProgressBar(ProgressBar):
            together. This corresponds to
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
            :class:`~lightning.pytorch.trainer.trainer.Trainer`.
        leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch.
            Default: ``False``
    """

    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
    def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
        super().__init__()
        self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
        self._process_position = process_position
@@ -113,6 +115,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        self._val_progress_bar: Optional[_tqdm] = None
        self._test_progress_bar: Optional[_tqdm] = None
        self._predict_progress_bar: Optional[_tqdm] = None
        self._leave = leave

    def __getstate__(self) -> Dict:
        # can't pickle the tqdm objects
@@ -262,6 +265,8 @@ def on_train_start(self, *_: Any) -> None:

    @override
    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        if self._leave:
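            # Create a fresh bar for the new epoch so the finished bar from the
            # previous epoch stays visible in the terminal.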
            self.train_progress_bar = self.init_train_tqdm()
        self.train_progress_bar.reset(convert_inf(self.total_train_batches))
        self.train_progress_bar.initial = 0
        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +284,8 @@ def on_train_batch_end(
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.train_progress_bar.disable:
            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
        if self._leave:
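            # Close the finished bar so it is left on screen; a new one is
            # created at the start of the next epoch.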
            self.train_progress_bar.close()

    @override
    def on_train_end(self, *_: Any) -> None:
19 changes: 18 additions & 1 deletion tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -18,7 +18,7 @@
from collections import defaultdict
from typing import Union
from unittest import mock
from unittest.mock import ANY, PropertyMock, call
from unittest.mock import ANY, Mock, PropertyMock, call

import pytest
import torch
Expand Down Expand Up @@ -783,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
    pbar.enable()
    trainer.test(model)
    assert pbar.is_disabled


@pytest.mark.parametrize("leave", [True, False])
def test_tqdm_leave(leave, tmp_path):
    pbar = TQDMProgressBar(leave=leave)
    pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[pbar],
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        benchmark=True,
    )
    trainer.fit(model)
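    # init_train_tqdm is called once in on_train_start; with leave=True it is
    # also called at each of the 3 epoch starts, giving 4 calls in total.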
    assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
