diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst index 385d17eb4e43a..24c4285cdfc5c 100644 --- a/docs/source-pytorch/common/progress_bar.rst +++ b/docs/source-pytorch/common/progress_bar.rst @@ -92,6 +92,7 @@ Customize the theme for your :class:`~lightning.pytorch.callbacks.RichProgressBa time="grey82", processing_speed="grey82", metrics="grey82", + metrics_text_delimiter="\n", ) ) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 76b6e94df747d..0bdfcda87c038 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -136,12 +136,13 @@ def render(self, task: "Task") -> RenderableType: class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"]): + def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"], text_delimiter: str): self._trainer = trainer self._tasks: Dict[Union[int, TaskID], Any] = {} self._current_task_id = 0 self._metrics: Dict[Union[str, "Style"], Any] = {} self._style = style + self._text_delimiter = text_delimiter super().__init__() def update(self, metrics: Dict[Any, Any]) -> None: @@ -167,7 +168,8 @@ def render(self, task: "Task") -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] - text = " ".join(self._generate_metrics_texts()) + metrics_texts = self._generate_metrics_texts() + text = self._text_delimiter.join(metrics_texts) return Text(text, justify="left", style=self._style) def _generate_metrics_texts(self) -> Generator[str, None, None]: @@ -204,6 +206,7 @@ class RichProgressBarTheme: time: Union[str, Style] = "grey54" processing_speed: Union[str, Style] = "grey70" metrics: Union[str, Style] = "white" + metrics_text_delimiter: str = " " class RichProgressBar(ProgressBar): @@ -325,7 +328,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: reconfigure(**self._console_kwargs) self._console = get_console() self._console.clear_live() - self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) + self._metric_component = MetricsTextColumn(trainer, self.theme.metrics, self.theme.metrics_text_delimiter) self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, @@ -376,7 +379,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo self.train_progress_bar_id = self._add_task(total_batches, train_description) else: self.progress.reset( - self.train_progress_bar_id, total=total_batches, description=train_description, visible=True + self.train_progress_bar_id, + total=total_batches, + description=train_description, + visible=True, ) self.refresh() @@ -399,7 +405,9 @@ def on_validation_batch_start( self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) self.val_sanity_progress_bar_id = self._add_task( - self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False + self.total_val_batches_current_dataloader, + self.sanity_check_description, + visible=False, ) else: if self.val_progress_bar_id is not None: @@ -407,14 +415,20 @@ def on_validation_batch_start( # TODO: remove old tasks when new onces are created self.val_progress_bar_id = self._add_task( - self.total_val_batches_current_dataloader, self.validation_description, visible=False + self.total_val_batches_current_dataloader, + self.validation_description, + visible=False, ) self.refresh() def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": assert self.progress is not None - return self.progress.add_task(f"[{self.theme.description}]{description}", total=total_batches, visible=visible) + return self.progress.add_task( + f"[{self.theme.description}]{description}", + total=total_batches, + visible=visible, + ) def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: if self.progress is not None and self.is_enabled: @@ -486,7 +500,12 @@ def on_predict_batch_start( self.refresh() def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, ) -> None: self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) @@ -573,7 +592,12 @@ def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: self._stop_progress() - def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + def on_exception( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + exception: BaseException, + ) -> None: self._stop_progress() def configure_columns(self, trainer: "pl.Trainer") -> list: