Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CSVLogger trying to append to file from previous run in same version folder #19446

Merged
merged 7 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))

-

Expand Down
21 changes: 14 additions & 7 deletions src/lightning/fabric/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class CSVLogger(Logger):
name: Experiment name. Defaults to ``'lightning_logs'``.
version: Experiment version. If version is not specified the logger inspects the save
directory for existing versions, then automatically assigns the next available version.
If the version is specified, and the directory already contains a metrics file for that version, it will be
overwritten.
prefix: A string to put at the beginning of metric keys.
flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

Expand Down Expand Up @@ -203,15 +205,11 @@ def __init__(self, log_dir: str) -> None:

self._fs = get_filesystem(log_dir)
self.log_dir = log_dir
if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):
rank_zero_warn(
f"Experiment logs directory {self.log_dir} exists and is not empty."
" Previous log files in this directory will be deleted when the new ones are saved!"
)
self._fs.makedirs(self.log_dir, exist_ok=True)

self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

self._check_log_dir_exists()
self._fs.makedirs(self.log_dir, exist_ok=True)

def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
"""Record metrics."""

Expand Down Expand Up @@ -264,3 +262,12 @@ def _rewrite_with_new_header(self, fieldnames: List[str]) -> None:
writer = csv.DictWriter(file, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(metrics)

def _check_log_dir_exists(self) -> None:
if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):
rank_zero_warn(
f"Experiment logs directory {self.log_dir} exists and is not empty."
" Previous log files in this directory will be deleted when the new ones are saved!"
)
if self._fs.isfile(self.metrics_file_path):
self._fs.rm_file(self.metrics_file_path)
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))


-

Expand Down
28 changes: 26 additions & 2 deletions tests/tests_fabric/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -50,6 +51,21 @@ def test_manual_versioning(tmp_path):
assert logger.version == 1


def test_manual_versioning_file_exists(tmp_path):
"""Test that a warning is emitted and existing files get overwritten."""

# Simulate an existing 'version_0' vrom a previous run
(tmp_path / "exp" / "version_0").mkdir(parents=True)
previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv"
previous_metrics_file.touch()

logger = CSVLogger(root_dir=tmp_path, name="exp", version=0)
assert previous_metrics_file.exists()
with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"):
_ = logger.experiment
assert not previous_metrics_file.exists()


def test_named_version(tmp_path):
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
exp_name = "exp"
Expand Down Expand Up @@ -130,7 +146,11 @@ def test_automatic_step_tracking(tmp_path):
assert logger.experiment.metrics[2]["step"] == 2


def test_append_metrics_file(tmp_path):
@mock.patch(
# Mock the existance check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_append_metrics_file(_, tmp_path):
"""Test that the logger appends to the file instead of rewriting it on every save."""
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)

Expand Down Expand Up @@ -167,7 +187,11 @@ def test_append_columns(tmp_path):
assert set(header.split(",")) == {"step", "a", "b", "c"}


def test_rewrite_with_new_header(tmp_path):
@mock.patch(
# Mock the existance check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_rewrite_with_new_header(_, tmp_path):
# write a csv file manually
with open(tmp_path / "metrics.csv", "w") as file:
file.write("step,metric1,metric2\n")
Expand Down
20 changes: 20 additions & 0 deletions tests/tests_pytorch/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock

import fsspec
Expand Down Expand Up @@ -51,6 +52,21 @@ def test_manual_versioning(tmp_path):
assert logger.version == 1


def test_manual_versioning_file_exists(tmp_path):
"""Test that a warning is emitted and existing files get overwritten."""

# Simulate an existing 'version_0' vrom a previous run
(tmp_path / "exp" / "version_0").mkdir(parents=True)
previous_metrics_file = tmp_path / "exp" / "version_0" / "metrics.csv"
previous_metrics_file.touch()

logger = CSVLogger(save_dir=tmp_path, name="exp", version=0)
assert previous_metrics_file.exists()
with pytest.warns(UserWarning, match="Experiment logs directory .* exists and is not empty"):
_ = logger.experiment
assert not previous_metrics_file.exists()


def test_named_version(tmp_path):
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
exp_name = "exp"
Expand Down Expand Up @@ -148,6 +164,10 @@ def test_metrics_reset_after_save(tmp_path):
assert not logger.experiment.metrics


@mock.patch(
# Mock the existance check, so we can simulate appending to the metrics file
"lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_append_metrics_file(tmp_path):
"""Test that the logger appends to the file instead of rewriting it on every save."""
logger = CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)
Expand Down
Loading