Skip to content

Commit

Permalink
Make delimiter in rich progress bar configurable (#18372)
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets authored Aug 24, 2023
1 parent f4825e5 commit 8680dc5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source-pytorch/common/progress_bar.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Customize the theme for your :class:`~lightning.pytorch.callbacks.RichProgressBa
time="grey82",
processing_speed="grey82",
metrics="grey82",
metrics_text_delimiter="\n",
)
)
Expand Down
42 changes: 33 additions & 9 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,13 @@ def render(self, task: "Task") -> RenderableType:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""

def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"]):
def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"], text_delimiter: str):
self._trainer = trainer
self._tasks: Dict[Union[int, TaskID], Any] = {}
self._current_task_id = 0
self._metrics: Dict[Union[str, "Style"], Any] = {}
self._style = style
self._text_delimiter = text_delimiter
super().__init__()

def update(self, metrics: Dict[Any, Any]) -> None:
Expand All @@ -167,7 +168,8 @@ def render(self, task: "Task") -> Text:
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]

text = " ".join(self._generate_metrics_texts())
metrics_texts = self._generate_metrics_texts()
text = self._text_delimiter.join(metrics_texts)
return Text(text, justify="left", style=self._style)

def _generate_metrics_texts(self) -> Generator[str, None, None]:
Expand Down Expand Up @@ -204,6 +206,7 @@ class RichProgressBarTheme:
time: Union[str, Style] = "grey54"
processing_speed: Union[str, Style] = "grey70"
metrics: Union[str, Style] = "white"
metrics_text_delimiter: str = " "


class RichProgressBar(ProgressBar):
Expand Down Expand Up @@ -325,7 +328,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None:
reconfigure(**self._console_kwargs)
self._console = get_console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics, self.theme.metrics_text_delimiter)
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
Expand Down Expand Up @@ -376,7 +379,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self.train_progress_bar_id = self._add_task(total_batches, train_description)
else:
self.progress.reset(
self.train_progress_bar_id, total=total_batches, description=train_description, visible=True
self.train_progress_bar_id,
total=total_batches,
description=train_description,
visible=True,
)

self.refresh()
Expand All @@ -399,22 +405,30 @@ def on_validation_batch_start(
self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)

self.val_sanity_progress_bar_id = self._add_task(
self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
self.total_val_batches_current_dataloader,
self.sanity_check_description,
visible=False,
)
else:
if self.val_progress_bar_id is not None:
self.progress.update(self.val_progress_bar_id, advance=0, visible=False)

# TODO: remove old tasks when new onces are created
self.val_progress_bar_id = self._add_task(
self.total_val_batches_current_dataloader, self.validation_description, visible=False
self.total_val_batches_current_dataloader,
self.validation_description,
visible=False,
)

self.refresh()

def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID":
assert self.progress is not None
return self.progress.add_task(f"[{self.theme.description}]{description}", total=total_batches, visible=visible)
return self.progress.add_task(
f"[{self.theme.description}]{description}",
total=total_batches,
visible=visible,
)

def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
if self.progress is not None and self.is_enabled:
Expand Down Expand Up @@ -486,7 +500,12 @@ def on_predict_batch_start(
self.refresh()

def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
self._update(self.train_progress_bar_id, batch_idx + 1)
self._update_metrics(trainer, pl_module)
Expand Down Expand Up @@ -573,7 +592,12 @@ def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
self._stop_progress()

def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
def on_exception(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
exception: BaseException,
) -> None:
self._stop_progress()

def configure_columns(self, trainer: "pl.Trainer") -> list:
Expand Down

0 comments on commit 8680dc5

Please sign in to comment.