From c277c90f12ce901d7e27cf3deb54706a9fe3364f Mon Sep 17 00:00:00 2001 From: Akash Kwatra Date: Fri, 25 Feb 2022 13:55:19 -0800 Subject: [PATCH] Add unit tests --- pytorch_lightning/utilities/logger.py | 8 +++--- tests/utilities/test_logger.py | 37 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py index 2bb541819eb1a..ef27761a2ec6e 100644 --- a/pytorch_lightning/utilities/logger.py +++ b/pytorch_lightning/utilities/logger.py @@ -148,17 +148,17 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[ return metrics -def _name(loggers: List[Any]) -> str: +def _name(loggers: List[Any], separator: str = "_") -> str: if len(loggers) == 1: return loggers[0].name else: # Concatenate names together, removing duplicates and preserving order - return "_".join(dict.fromkeys(str(logger.name) for logger in loggers)) + return separator.join(dict.fromkeys(str(logger.name) for logger in loggers)) -def _version(loggers: List[Any]) -> Union[int, str]: +def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: return loggers[0].version else: # Concatenate versions together, removing duplicates and preserving order - return "_".join(dict.fromkeys(str(logger.version) for logger in loggers)) + return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py index 8d9b495fb96bf..02d2795defb50 100644 --- a/tests/utilities/test_logger.py +++ b/tests/utilities/test_logger.py @@ -17,12 +17,15 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.loggers import CSVLogger from pytorch_lightning.utilities.logger import ( _add_prefix, _convert_params, _flatten_dict, + _name, _sanitize_callable_params, _sanitize_params, + _version, ) @@ -172,3 +175,37 @@ def test_add_prefix(): assert "prefix-metric2" not in metrics assert metrics["prefix2_prefix-metric1"] == 1 assert metrics["prefix2_prefix-metric2"] == 2 + + +def test_name(tmpdir): + """Verify names of loggers are concatenated properly.""" + logger1 = CSVLogger(tmpdir, name="foo") + logger2 = CSVLogger(tmpdir, name="bar") + logger3 = CSVLogger(tmpdir, name="foo") + logger4 = CSVLogger(tmpdir, name="baz") + loggers = [logger1, logger2, logger3, logger4] + name = _name([]) + assert name == "" + name = _name([logger3]) + assert name == "foo" + name = _name(loggers) + assert name == "foo_bar_baz" + name = _name(loggers, "-") + assert name == "foo-bar-baz" + + +def test_version(tmpdir): + """Verify names of loggers are concatenated properly.""" + logger1 = CSVLogger(tmpdir, version=0) + logger2 = CSVLogger(tmpdir, version=2) + logger3 = CSVLogger(tmpdir, version=1) + logger4 = CSVLogger(tmpdir, version=0) + loggers = [logger1, logger2, logger3, logger4] + version = _version([]) + assert version == "" + version = _version([logger3]) + assert version == 1 + version = _version(loggers) + assert version == "0_2_1" + version = _version(loggers, "-") + assert version == "0-2-1"