Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename ProgressBarBase to ProgressBar #17058

Merged
merged 3 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source-pytorch/api_references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ callbacks
ModelPruning
ModelSummary
OnExceptionCheckpoint
ProgressBarBase
ProgressBar
RichModelSummary
RichProgressBar
StochasticWeightAveraging
Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/progress_bar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Customize the progress bar
Lightning supports two different types of progress bars (`tqdm <https://github.com/tqdm/tqdm>`_ and `rich <https://github.com/Textualize/rich>`_). :class:`~lightning.pytorch.callbacks.TQDMProgressBar` is used by default,
but you can override it by passing a custom :class:`~lightning.pytorch.callbacks.TQDMProgressBar` or :class:`~lightning.pytorch.callbacks.RichProgressBar` to the ``callbacks`` argument of the :class:`~lightning.pytorch.trainer.trainer.Trainer`.

You could also use the :class:`~lightning.pytorch.callbacks.ProgressBarBase` class to implement your own progress bar.
You could also use the :class:`~lightning.pytorch.callbacks.ProgressBar` class to implement your own progress bar.

-------------

Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Lightning has a few built-in callbacks.
ModelCheckpoint
ModelPruning
ModelSummary
ProgressBarBase
ProgressBar
RichModelSummary
RichProgressBar
StochasticWeightAveraging
Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ Modifying the Progress Bar

The progress bar by default already includes the training loss and version number of the experiment
if you are using a logger. These defaults can be customized by overriding the
:meth:`~lightning.pytorch.callbacks.progress.base.ProgressBarBase.get_metrics` hook in your logger.
:meth:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics` hook in your logger.

.. code-block:: python

Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/visualize/logging_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Track and Visualize Experiments (advanced)
****************************
Change progress bar defaults
****************************
To change the default values (ie: version number) shown in the progress bar, override the :meth:`~lightning.pytorch.callbacks.progress.base.ProgressBarBase.get_metrics` method in your logger.
To change the default values (ie: version number) shown in the progress bar, override the :meth:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar.get_metrics` method in your logger.

.. code-block:: python

Expand Down
6 changes: 3 additions & 3 deletions docs/source-pytorch/visualize/logging_expert.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ To customize either the :class:`~lightning.pytorch.callbacks.TQDMProgressBar` o
***************************
Build your own progress bar
***************************
To build your own progress bar, subclass :class:`~lightning.pytorch.callbacks.ProgressBarBase`
To build your own progress bar, subclass :class:`~lightning.pytorch.callbacks.ProgressBar`

.. code-block:: python

from lightning.pytorch.callbacks import ProgressBarBase
from lightning.pytorch.callbacks import ProgressBar


class LitProgressBar(ProgressBarBase):
class LitProgressBar(ProgressBar):
def __init__(self):
super().__init__() # don't forget this :)
self.enable = True
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/cli/pl-app-template/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lightning.app.storage import Path
from lightning.app.utilities.app_helpers import Logger
from lightning.pytorch import Callback
from lightning.pytorch.callbacks.progress.base import get_standard_metrics
from lightning.pytorch.callbacks.progress.progress_bar import get_standard_metrics
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.utilities.parsing import collect_init_args

Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Renamed `ProgressBarBase` to `ProgressBar` ([#17058](https://github.com/Lightning-AI/lightning/pull/17058))


- The `Trainer` now chooses `accelerator="auto", strategy="auto", devices="auto"` as defaults ([#16847](https://github.com/Lightning-AI/lightning/pull/16847))

Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from lightning.pytorch.callbacks.model_summary import ModelSummary
from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint
from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.callbacks.progress import ProgressBarBase, RichProgressBar, TQDMProgressBar
from lightning.pytorch.callbacks.progress import ProgressBar, RichProgressBar, TQDMProgressBar
from lightning.pytorch.callbacks.pruning import ModelPruning
from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
Expand All @@ -48,7 +48,7 @@
"ModelPruning",
"ModelSummary",
"OnExceptionCheckpoint",
"ProgressBarBase",
"ProgressBar",
"RichModelSummary",
"RichProgressBar",
"StochasticWeightAveraging",
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/progress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
Use or override one of the progress bar callbacks.

"""
from lightning.pytorch.callbacks.progress.base import ProgressBarBase # noqa: F401
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from lightning.pytorch.utilities.rank_zero import rank_zero_warn


class ProgressBarBase(Callback):
class ProgressBar(Callback):
r"""
The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback`
that keeps track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
You should implement your highly custom progress bars with this as the base class.

Example::

class LitProgressBar(ProgressBarBase):
class LitProgressBar(ProgressBar):

def __init__(self):
super().__init__() # don't forget this :)
Expand Down
10 changes: 5 additions & 5 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from lightning_utilities.core.imports import RequirementCache

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.types import STEP_OUTPUT

_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")
Expand All @@ -28,19 +28,19 @@
from rich import get_console, reconfigure
from rich.console import Console, RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar
from rich.progress_bar import ProgressBar as _RichProgressBar
from rich.style import Style
from rich.text import Text

class CustomBarColumn(BarColumn):
"""Overrides ``BarColumn`` to provide support for dataloaders that do not define a size (infinite size)
such as ``IterableDataset``."""

def render(self, task: "Task") -> ProgressBar:
def render(self, task: "Task") -> _RichProgressBar:
"""Gets a progress bar widget for a task."""
assert task.total is not None
assert task.remaining is not None
return ProgressBar(
return _RichProgressBar(
total=max(0, task.total),
completed=max(0, task.completed),
width=None if self.bar_width is None else max(1, self.bar_width),
Expand Down Expand Up @@ -204,7 +204,7 @@ class RichProgressBarTheme:
metrics: Union[str, Style] = "white"


class RichProgressBar(ProgressBarBase):
class RichProgressBar(ProgressBar):
"""Create a progress bar with `rich text formatting <https://github.com/Textualize/rich>`_.

Install it with pip:
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from tqdm import tqdm as _tqdm

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.base import ProgressBarBase
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

_PAD_SIZE = 5
Expand Down Expand Up @@ -59,7 +59,7 @@ def format_num(n: Union[int, float, str]) -> str:
return n


class TQDMProgressBar(ProgressBarBase):
class TQDMProgressBar(ProgressBar):
r"""
This is the default progress bar used by Lightning. It prints to ``stdout`` using the
:mod:`tqdm` package and shows up to four different bars:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Checkpoint,
ModelCheckpoint,
ModelSummary,
ProgressBarBase,
ProgressBar,
RichProgressBar,
TQDMProgressBar,
)
Expand Down Expand Up @@ -115,7 +115,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
self.trainer.callbacks.append(model_summary)

def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBar)]
if len(progress_bars) > 1:
raise MisconfigurationException(
"You added multiple progress bar callbacks to the Trainer, but currently only one"
Expand Down
8 changes: 4 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
Expand Down Expand Up @@ -1193,11 +1193,11 @@ def checkpoint_callbacks(self) -> List[Checkpoint]:
return [c for c in self.callbacks if isinstance(c, Checkpoint)]

@property
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
"""An instance of :class:`~lightning.pytorch.callbacks.progress.base.ProgressBarBase` found in the
def progress_bar_callback(self) -> Optional[ProgressBar]:
"""An instance of :class:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar` found in the
Trainer.callbacks list, or ``None`` if one doesn't exist."""
for c in self.callbacks:
if isinstance(c, ProgressBarBase):
if isinstance(c, ProgressBar):
return c
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ProgressBarBase, RichProgressBar
from lightning.pytorch.callbacks import ProgressBar, RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.loggers import CSVLogger
Expand All @@ -31,7 +31,7 @@
def test_rich_progress_bar_callback():
trainer = Trainer(callbacks=RichProgressBar())

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]

assert len(progress_bars) == 1
assert isinstance(trainer.progress_bar_callback, RichProgressBar)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.utils.data.dataloader import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar
from lightning.pytorch.callbacks import ModelCheckpoint, ProgressBar, TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
Expand Down Expand Up @@ -85,15 +85,15 @@ def test_tqdm_progress_bar_on(tmpdir, pbar):
"""Test different ways the progress bar can be turned on."""
trainer = Trainer(default_root_dir=tmpdir, callbacks=pbar)

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]
assert len(progress_bars) == 1
assert progress_bars[0] is trainer.progress_bar_callback


def test_tqdm_progress_bar_off(tmpdir):
"""Test turning the progress bar off."""
trainer = Trainer(default_root_dir=tmpdir, enable_progress_bar=False)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]
assert not len(progress_bars)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
LearningRateMonitor,
ModelCheckpoint,
ModelSummary,
ProgressBarBase,
ProgressBar,
TQDMProgressBar,
)
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_attach_model_callbacks():
def _attach_callbacks(trainer_callbacks, model_callbacks):
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
has_progress_bar = any(isinstance(cb, ProgressBarBase) for cb in trainer_callbacks + model_callbacks)
has_progress_bar = any(isinstance(cb, ProgressBar) for cb in trainer_callbacks + model_callbacks)
trainer = Trainer(
enable_checkpointing=False,
enable_progress_bar=has_progress_bar,
Expand Down