Commit

Use a single instance of rich.console.Console throughout the codebase (#12886)
otaj authored Apr 27, 2022
1 parent cac02a0 commit a414862
Showing 6 changed files with 33 additions and 25 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))


+- Use only a single instance of `rich.console.Console` throughout codebase ([#12886](https://github.com/PyTorchLightning/pytorch-lightning/pull/12886))


## [1.6.1] - 2022-04-13

### Changed
6 changes: 4 additions & 2 deletions pytorch_lightning/callbacks/progress/rich_progress.py
@@ -22,7 +22,8 @@

Task, Style = None, None
if _RICH_AVAILABLE:
-from rich.console import Console, RenderableType
+from rich import get_console, reconfigure
+from rich.console import RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
from rich.progress_bar import ProgressBar
from rich.style import Style
@@ -278,7 +279,8 @@ def enable(self) -> None:
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
-self._console = Console(**self._console_kwargs)
+reconfigure(**self._console_kwargs)
+self._console = get_console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
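For context, `rich` maintains a single module-level console: `get_console()` returns it and `reconfigure()` rebuilds it in place with new options, which is the pattern the progress bar adopts above. A minimal sketch of that behaviour, where `width=120` is only a hypothetical stand-in for whatever `self._console_kwargs` contains:

```python
from rich import get_console, reconfigure

# reconfigure() swaps the global console for one built with the new options;
# get_console() then returns that shared instance to every caller.
reconfigure(width=120)  # hypothetical stand-in for **self._console_kwargs
console = get_console()

assert console is get_console()  # every later caller sees the same object
console.print("progress bar and model summary now share one console")
```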
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/rich_model_summary.py
@@ -18,7 +18,7 @@
from pytorch_lightning.utilities.model_summary import get_human_readable_count

if _RICH_AVAILABLE:
-from rich.console import Console
+from rich import get_console
from rich.table import Table


@@ -73,7 +73,7 @@ def summarize(
model_size: float,
) -> None:

-console = Console()
+console = get_console()

table = Table(header_style="bold magenta")
table.add_column(" ", style="dim")
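The model-summary callback follows the same idea: rather than constructing its own `Console`, it renders its table through the shared one. A rough sketch of that flow, using made-up column and row values rather than the real summary contents:

```python
from rich import get_console
from rich.table import Table

table = Table(header_style="bold magenta")
table.add_column(" ", style="dim")
table.add_column("Name", justify="left")
table.add_column("Type")
table.add_row("0", "encoder", "Linear")  # illustrative values only

# Printing via the global console keeps the summary consistent with any
# settings already applied to it elsewhere (e.g. by the progress bar).
get_console().print(table)
```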
20 changes: 8 additions & 12 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -16,7 +16,7 @@
import sys
from collections import ChainMap, OrderedDict
from functools import partial
-from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
+from typing import Any, Iterable, List, Optional, Sequence, Type, Union

import torch
from deprecate.utils import void
@@ -42,7 +42,7 @@
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

if _RICH_AVAILABLE:
-from rich.console import Console
+from rich import get_console
from rich.table import Column, Table


@@ -319,11 +319,7 @@ def _find_value(data: dict, target: str) -> Iterable[Any]:
yield from EvaluationLoop._find_value(v, target)

@staticmethod
-def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
-# print to stdout by default
-if file is None:
-file = sys.stdout

+def _print_results(results: List[_OUT_DICT], stage: str) -> None:
# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
@@ -358,24 +354,24 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
table_headers.insert(0, f"{stage} Metric".capitalize())

if _RICH_AVAILABLE:
-console = Console(file=file)

columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers]
columns[0].style = "cyan"

table = Table(*columns)
for metric, row in zip(metrics, table_rows):
row.insert(0, metric)
table.add_row(*row)

+console = get_console()
console.print(table)
else:
row_format = f"{{:^{max_length}}}" * len(table_headers)
half_term_size = int(term_size / 2)

try:
# some terminals do not support this character
-if hasattr(file, "encoding") and file.encoding is not None:
-"─".encode(file.encoding)
+if sys.stdout.encoding is not None:
+"─".encode(sys.stdout.encoding)
except UnicodeEncodeError:
bar_character = "-"
else:
@@ -394,7 +390,7 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
else:
lines.append(row_format.format(metric, *row).rstrip())
lines.append(bar)
-print(os.linesep.join(lines), file=file)
+print(os.linesep.join(lines))


def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
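With the `file` parameter removed, the plain-text fallback always writes to `sys.stdout`, so the box-drawing check probes the interpreter's stdout encoding directly. A small sketch of that fallback, mirroring the logic in the hunk above:

```python
import sys

# Use the "─" rule character only if stdout can actually encode it;
# otherwise fall back to a plain ASCII dash.
try:
    if sys.stdout.encoding is not None:
        "─".encode(sys.stdout.encoding)
except UnicodeEncodeError:
    bar_character = "-"
else:
    bar_character = "─"

print(bar_character * 40)
```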
4 changes: 2 additions & 2 deletions tests/callbacks/test_rich_model_summary.py
@@ -41,8 +41,8 @@ def test_rich_progress_bar_import_error(monkeypatch):


@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True)
@mock.patch("rich.console.Console.print", autospec=True)
@mock.patch("rich.table.Table.add_row", autospec=True)
def test_rich_summary_tuples(mock_table_add_row, mock_console):
"""Ensure that tuples are converted into string, and print is called correctly."""
model_summary = RichModelSummary()
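Since `rich_model_summary` no longer imports `Console` into its own namespace, the test patches `Console.print` where it is defined; patching the class at its source also intercepts calls made through `get_console()`. A small illustration of why that target works, independent of Lightning:

```python
from unittest import mock

from rich import get_console

# Patching Console.print on the class means any instance, including the
# global console returned by get_console(), hits the mock.
with mock.patch("rich.console.Console.print", autospec=True) as mock_print:
    get_console().print("hello")
    mock_print.assert_called_once()
```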
21 changes: 14 additions & 7 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -15,6 +15,7 @@
import collections
import itertools
import os
+from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import call
@@ -28,10 +29,13 @@
from pytorch_lightning.loops.dataloader import EvaluationLoop
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
+from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _RICH_AVAILABLE
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

+if _RICH_AVAILABLE:
+from rich import get_console


def test__validation_step__log(tmpdir):
"""Tests that validation_step can log."""
@@ -864,8 +868,9 @@ def test_native_print_results(monkeypatch, inputs, expected):
import pytorch_lightning.loops.dataloader.evaluation_loop as imports

monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
-out = StringIO()
-EvaluationLoop._print_results(*inputs, file=out)

+with redirect_stdout(StringIO()) as out:
+EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()

@@ -878,7 +883,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):

out = mock.Mock()
out.encoding = encoding
-EvaluationLoop._print_results(*inputs0, file=out)
+with redirect_stdout(out) as out:
+EvaluationLoop._print_results(*inputs0)

# Attempt to encode everything the file is told to write with the given encoding
for call_ in out.method_calls:
@@ -950,7 +956,8 @@ def test_native_print_results_encodings(monkeypatch, encoding):
)
@RunIf(skip_windows=True, rich=True)
def test_rich_print_results(inputs, expected):
-out = StringIO()
-EvaluationLoop._print_results(*inputs, file=out)
+console = get_console()
+with console.capture() as capture:
+EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
-assert out.getvalue() == expected.lstrip()
+assert capture.get() == expected.lstrip()
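Because `_print_results` now writes to stdout (or to the global rich console) rather than to an injected file, the tests capture output in two different ways: `contextlib.redirect_stdout` for the plain `print` branch and `Console.capture()` for the rich branch. A condensed sketch of both capture styles, with placeholder strings standing in for the real metric tables:

```python
from contextlib import redirect_stdout
from io import StringIO

from rich import get_console

# Plain-text branch: collect everything print() sends to stdout.
with redirect_stdout(StringIO()) as out:
    print("native metrics table")  # placeholder for _print_results(...)
assert "native metrics table" in out.getvalue()

# Rich branch: the shared console records its own output while capturing.
console = get_console()
with console.capture() as capture:
    console.print("rich metrics table")  # placeholder for _print_results(...)
assert "rich metrics table" in capture.get()
```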
