Skip to content

Commit

Permalink
Fix detection of next version in Fabric's CSVLogger (#17986)
Browse files Browse the repository at this point in the history
(cherry picked from commit 356f5d0)
  • Loading branch information
awaelchli authored and lantiga committed Jul 21, 2023
1 parent 1fae4f6 commit 1ab3da6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788))


- Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139), [#17139](https://github.com/Lightning-AI/lightning/pull/17986))


## [2.0.3] - 2023-06-07

- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
Expand Down
12 changes: 6 additions & 6 deletions src/lightning/fabric/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def log_dir(self) -> str:
"""
# create a pseudo standard path
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
return os.path.join(self.root_dir, self.name, version)
return os.path.join(self._root_dir, self.name, version)

@property
@rank_zero_experiment
Expand All @@ -120,7 +120,7 @@ def experiment(self) -> "_ExperimentWriter":
if self._experiment is not None:
return self._experiment

os.makedirs(self.root_dir, exist_ok=True)
os.makedirs(self._root_dir, exist_ok=True)
self._experiment = _ExperimentWriter(log_dir=self.log_dir)
return self._experiment

Expand Down Expand Up @@ -153,14 +153,14 @@ def finalize(self, status: str) -> None:
self.save()

def _get_next_version(self) -> int:
root_dir = self.root_dir
versions_root = os.path.join(self._root_dir, self.name)

if not self._fs.isdir(root_dir):
log.warning("Missing logger folder: %s", root_dir)
if not self._fs.isdir(versions_root):
log.warning("Missing logger folder: %s", versions_root)
return 0

existing_versions = []
for d in self._fs.listdir(root_dir):
for d in self._fs.listdir(versions_root):
full_path = d["name"]
name = os.path.basename(full_path)
if self._fs.isdir(full_path) and name.startswith("version_"):
Expand Down
23 changes: 10 additions & 13 deletions tests/tests_fabric/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,20 @@
from lightning.fabric.loggers.csv_logs import _ExperimentWriter


def test_file_logger_automatic_versioning(tmpdir):
def test_file_logger_automatic_versioning(tmp_path):
"""Verify that automatic versioning works."""
root_dir = tmpdir.mkdir("exp")
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")
logger = CSVLogger(root_dir=root_dir, name="exp")
(tmp_path / "exp" / "version_0").mkdir(parents=True)
(tmp_path / "exp" / "version_1").mkdir()
logger = CSVLogger(root_dir=tmp_path, name="exp")
assert logger.version == 2


def test_file_logger_automatic_versioning_relative_root_dir(tmpdir, monkeypatch):
def test_file_logger_automatic_versioning_relative_root_dir(tmp_path, monkeypatch):
"""Verify that automatic versioning works, when root_dir is given a relative path."""
root_dir = tmpdir.mkdir("exp")
logs_dir = root_dir.mkdir("logs")
logs_dir.mkdir("version_0")
logs_dir.mkdir("version_1")
monkeypatch.chdir(tmpdir)
logger = CSVLogger(root_dir="exp/logs", name="logs")
(tmp_path / "exp" / "logs" / "version_0").mkdir(parents=True)
(tmp_path / "exp" / "logs" / "version_1").mkdir()
monkeypatch.chdir(tmp_path)
logger = CSVLogger(root_dir="exp", name="logs")
assert logger.version == 2


Expand Down Expand Up @@ -71,7 +68,7 @@ def test_file_logger_no_name(tmpdir, name):
logger = CSVLogger(root_dir=tmpdir, name=name)
logger.log_metrics({"a": 1})
logger.save()
assert os.path.normpath(logger.root_dir) == tmpdir # use os.path.normpath to handle trailing /
assert os.path.normpath(logger._root_dir) == tmpdir # use os.path.normpath to handle trailing /
assert os.listdir(tmpdir / "version_0")


Expand Down

0 comments on commit 1ab3da6

Please sign in to comment.