Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akashkw committed Feb 25, 2022
1 parent 6cf1eb4 commit c277c90
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pytorch_lightning/utilities/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[
return metrics


def _name(loggers: List[Any]) -> str:
def _name(loggers: List[Any], separator: str = "_") -> str:
if len(loggers) == 1:
return loggers[0].name
else:
# Concatenate names together, removing duplicates and preserving order
return "_".join(dict.fromkeys(str(logger.name) for logger in loggers))
return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))


def _version(loggers: List[Any]) -> Union[int, str]:
def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
if len(loggers) == 1:
return loggers[0].version
else:
# Concatenate versions together, removing duplicates and preserving order
return "_".join(dict.fromkeys(str(logger.version) for logger in loggers))
return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))
37 changes: 37 additions & 0 deletions tests/utilities/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.logger import (
_add_prefix,
_convert_params,
_flatten_dict,
_name,
_sanitize_callable_params,
_sanitize_params,
_version,
)


Expand Down Expand Up @@ -172,3 +175,37 @@ def test_add_prefix():
assert "prefix-metric2" not in metrics
assert metrics["prefix2_prefix-metric1"] == 1
assert metrics["prefix2_prefix-metric2"] == 2


def test_name(tmpdir):
"""Verify names of loggers are concatenated properly."""
logger1 = CSVLogger(tmpdir, name="foo")
logger2 = CSVLogger(tmpdir, name="bar")
logger3 = CSVLogger(tmpdir, name="foo")
logger4 = CSVLogger(tmpdir, name="baz")
loggers = [logger1, logger2, logger3, logger4]
name = _name([])
assert name == ""
name = _name([logger3])
assert name == "foo"
name = _name(loggers)
assert name == "foo_bar_baz"
name = _name(loggers, "-")
assert name == "foo-bar-baz"


def test_version(tmpdir):
"""Verify names of loggers are concatenated properly."""
logger1 = CSVLogger(tmpdir, version=0)
logger2 = CSVLogger(tmpdir, version=2)
logger3 = CSVLogger(tmpdir, version=1)
logger4 = CSVLogger(tmpdir, version=0)
loggers = [logger1, logger2, logger3, logger4]
version = _version([])
assert version == ""
version = _version([logger3])
assert version == 1
version = _version(loggers)
assert version == "0_2_1"
version = _version(loggers, "-")
assert version == "0-2-1"

0 comments on commit c277c90

Please sign in to comment.