Add ability for TQDMProgressBar to retain prior epoch training bars #19578

Merged · 9 commits · Aug 4, 2024
8 changes: 8 additions & 0 deletions docs/source-pytorch/common/progress_bar.rst
@@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress

trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])

By default, the training progress bar is reset (overwritten) at the start of each new epoch. If you prefer to keep the finished bar of every epoch and start a fresh one for the next, set
:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.

.. code-block:: python

trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])

If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
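
As an illustration (an editorial sketch, not part of this diff), such a subclass could override one of the ``init_*_tqdm`` hooks; ``init_validation_tqdm`` below is assumed to behave as in the current callback:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import TQDMProgressBar


    class LitProgressBar(TQDMProgressBar):
        def init_validation_tqdm(self):
            # Reuse the default validation bar and only change its description.
            bar = super().init_validation_tqdm()
            bar.set_description("running validation...")
            return bar


    trainer = Trainer(callbacks=[LitProgressBar()])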

2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))

-
- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))

### Changed

9 changes: 8 additions & 1 deletion src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -99,12 +99,14 @@ class TQDMProgressBar(ProgressBar):
together. This corresponds to
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
:class:`~lightning.pytorch.trainer.trainer.Trainer`.
leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch.
Default: ``False``

"""

BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

def __init__(self, refresh_rate: int = 1, process_position: int = 0):
def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
super().__init__()
self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
self._process_position = process_position
@@ -113,6 +115,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
self._val_progress_bar: Optional[_tqdm] = None
self._test_progress_bar: Optional[_tqdm] = None
self._predict_progress_bar: Optional[_tqdm] = None
self._leave = leave

def __getstate__(self) -> Dict:
# can't pickle the tqdm objects
@@ -262,6 +265,8 @@ def on_train_start(self, *_: Any) -> None:

@override
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._leave:
self.train_progress_bar = self.init_train_tqdm()
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +284,8 @@ def on_train_batch_end(
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.train_progress_bar.disable:
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
if self._leave:
self.train_progress_bar.close()

@override
def on_train_end(self, *_: Any) -> None:
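
To make the behaviour of ``leave`` concrete, here is a minimal standalone sketch (plain ``tqdm``, independent of this PR) of the pattern the callback implements when ``leave=True``: the finished bar is closed at epoch end and a fresh one is created at the next epoch start, so prior bars stay visible in the terminal.

.. code-block:: python

    import time

    from tqdm import tqdm

    num_epochs, num_batches = 3, 10
    for epoch in range(num_epochs):
        # One bar per epoch; leave=True keeps the finished bar on screen after close().
        bar = tqdm(total=num_batches, desc=f"Epoch {epoch}", leave=True)
        for _ in range(num_batches):
            time.sleep(0.01)  # stand-in for a training step
            bar.update(1)
        bar.close()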
@@ -18,7 +18,7 @@
from collections import defaultdict
from typing import Union
from unittest import mock
from unittest.mock import ANY, PropertyMock, call
from unittest.mock import ANY, Mock, PropertyMock, call

import pytest
import torch
@@ -783,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
pbar.enable()
trainer.test(model)
assert pbar.is_disabled


@pytest.mark.parametrize("leave", [True, False])
def test_tqdm_leave(leave, tmp_path):
pbar = TQDMProgressBar(leave=leave)
pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm)
model = BoringModel()
trainer = Trainer(
default_root_dir=tmp_path,
callbacks=[pbar],
max_epochs=3,
limit_train_batches=1,
limit_val_batches=1,
benchmark=True,
)
trainer.fit(model)
assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
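
The expected counts follow from the hooks changed above (assuming ``on_train_start`` creates the initial bar via ``init_train_tqdm``, as in the existing callback): with ``leave=True`` the hook runs once at training start and once more in each of the three ``on_train_epoch_start`` calls, giving 1 + 3 = 4; with ``leave=False`` only the initial call occurs.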