From 4bead5c1b43df864a459047682f49f533843e5c7 Mon Sep 17 00:00:00 2001 From: Quinten Date: Tue, 22 Aug 2023 16:35:30 +0200 Subject: [PATCH 1/4] optimize render function --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 48aee9673e4c4..652dc43b3958f 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -167,11 +167,13 @@ def render(self, task: "Task") -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] - text = "" - for k, v in self._metrics.items(): - text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " + text = "".join(self.generate_metrics_texts()) return Text(text, justify="left", style=self._style) + def generate_metrics_texts(self): + for k, v in self._metrics.items(): + yield f"{k}: {round(v, 3) if isinstance(v, float) else v} " + else: Task, Style = Any, Any # type: ignore[assignment, misc] From 24fdfed90e789b07393fc6fe22ae590a726b1377 Mon Sep 17 00:00:00 2001 From: Quinten Date: Wed, 23 Aug 2023 01:10:30 +0200 Subject: [PATCH 2/4] move space to joining delimiter --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 652dc43b3958f..444b224fbbc24 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -167,12 +167,12 @@ def render(self, task: "Task") -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] - text = "".join(self.generate_metrics_texts()) + text = " ".join(self.generate_metrics_texts()) return Text(text, justify="left", style=self._style) def generate_metrics_texts(self): for k, v in self._metrics.items(): - yield f"{k}: {round(v, 3) if isinstance(v, float) else v} " + yield f"{k}: {round(v, 3) if isinstance(v, float) else v}" else: Task, Style = Any, Any # type: ignore[assignment, misc] From ac1c8b14af71e358206553a7d166e2d07cbfc918 Mon Sep 17 00:00:00 2001 From: Quinten Date: Wed, 23 Aug 2023 01:11:14 +0200 Subject: [PATCH 3/4] make metric generation function private --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 444b224fbbc24..14587ac723468 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -167,10 +167,10 @@ def render(self, task: "Task") -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] - text = " ".join(self.generate_metrics_texts()) + text = " ".join(self._generate_metrics_texts()) return Text(text, justify="left", style=self._style) - def generate_metrics_texts(self): + def _generate_metrics_texts(self): for k, v in self._metrics.items(): yield f"{k}: {round(v, 3) if isinstance(v, float) else v}" From eaac0ae059bf8dcabc3a9231261dea71c0f40722 Mon Sep 17 00:00:00 2001 From: Quinten Date: Wed, 23 Aug 2023 01:32:53 +0200 Subject: [PATCH 4/4] add return type annotation --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 14587ac723468..76b6e94df747d 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass from datetime import timedelta -from typing import Any, cast, Dict, Optional, Union +from typing import Any, cast, Dict, Generator, Optional, Union from lightning_utilities.core.imports import RequirementCache @@ -170,7 +170,7 @@ def render(self, task: "Task") -> Text: text = " ".join(self._generate_metrics_texts()) return Text(text, justify="left", style=self._style) - def _generate_metrics_texts(self): + def _generate_metrics_texts(self) -> Generator[str, None, None]: for k, v in self._metrics.items(): yield f"{k}: {round(v, 3) if isinstance(v, float) else v}"