From 1ab3da69a717d84b30726e5aa171cf9e9ca183c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 14 Jul 2023 22:08:16 +0200 Subject: [PATCH] Fix detection of next version in Fabric's CSVLogger (#17986) (cherry picked from commit 356f5d0c657faf128182086fa3b1ec5d929d2ca5) --- src/lightning/fabric/CHANGELOG.md | 3 +++ src/lightning/fabric/loggers/csv_logs.py | 12 ++++++------ tests/tests_fabric/loggers/test_csv.py | 23 ++++++++++------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 715e2e3261c4b..346740311dcd2 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788)) +- Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139), [#17139](https://github.com/Lightning-AI/lightning/pull/17986)) + + ## [2.0.3] - 2023-06-07 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756)) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index 312e48e6c2e76..31756502e95f7 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -105,7 +105,7 @@ def log_dir(self) -> str: """ # create a pseudo standard path version = self.version if isinstance(self.version, str) else f"version_{self.version}" - return os.path.join(self.root_dir, self.name, version) + return os.path.join(self._root_dir, self.name, version) @property @rank_zero_experiment @@ -120,7 +120,7 @@ def experiment(self) -> "_ExperimentWriter": if self._experiment is not None: return self._experiment - os.makedirs(self.root_dir, exist_ok=True) + os.makedirs(self._root_dir, exist_ok=True) self._experiment = _ExperimentWriter(log_dir=self.log_dir) return self._experiment @@ -153,14 +153,14 @@ def finalize(self, status: str) -> None: self.save() def _get_next_version(self) -> int: - root_dir = self.root_dir + versions_root = os.path.join(self._root_dir, self.name) - if not self._fs.isdir(root_dir): - log.warning("Missing logger folder: %s", root_dir) + if not self._fs.isdir(versions_root): + log.warning("Missing logger folder: %s", versions_root) return 0 existing_versions = [] - for d in self._fs.listdir(root_dir): + for d in self._fs.listdir(versions_root): full_path = d["name"] name = os.path.basename(full_path) if self._fs.isdir(full_path) and name.startswith("version_"): diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 045a9aa67e848..98d21ae1b7c08 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -21,23 +21,20 @@ from lightning.fabric.loggers.csv_logs import _ExperimentWriter -def test_file_logger_automatic_versioning(tmpdir): +def test_file_logger_automatic_versioning(tmp_path): """Verify that automatic versioning works.""" - root_dir = tmpdir.mkdir("exp") - root_dir.mkdir("version_0") - root_dir.mkdir("version_1") - logger = CSVLogger(root_dir=root_dir, name="exp") + (tmp_path / "exp" / "version_0").mkdir(parents=True) + (tmp_path / "exp" / "version_1").mkdir() + logger = CSVLogger(root_dir=tmp_path, name="exp") assert logger.version == 2 -def test_file_logger_automatic_versioning_relative_root_dir(tmpdir, monkeypatch): +def test_file_logger_automatic_versioning_relative_root_dir(tmp_path, monkeypatch): """Verify that automatic versioning works, when root_dir is given a relative path.""" - root_dir = tmpdir.mkdir("exp") - logs_dir = root_dir.mkdir("logs") - logs_dir.mkdir("version_0") - logs_dir.mkdir("version_1") - monkeypatch.chdir(tmpdir) - logger = CSVLogger(root_dir="exp/logs", name="logs") + (tmp_path / "exp" / "logs" / "version_0").mkdir(parents=True) + (tmp_path / "exp" / "logs" / "version_1").mkdir() + monkeypatch.chdir(tmp_path) + logger = CSVLogger(root_dir="exp", name="logs") assert logger.version == 2 @@ -71,7 +68,7 @@ def test_file_logger_no_name(tmpdir, name): logger = CSVLogger(root_dir=tmpdir, name=name) logger.log_metrics({"a": 1}) logger.save() - assert os.path.normpath(logger.root_dir) == tmpdir # use os.path.normpath to handle trailing / + assert os.path.normpath(logger._root_dir) == tmpdir # use os.path.normpath to handle trailing / assert os.listdir(tmpdir / "version_0")