From d261760a2fc32c36faa3b074a6a14a69e65cee5d Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Tue, 4 Jul 2023 14:14:57 -0400 Subject: [PATCH] add cache_images to lightning (#614) * add cache_images to lightning * fix cache_images in lightning * add test for lightning cache_images --- src/dvclive/lightning.py | 2 ++ tests/test_frameworks/test_lightning.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index cfb8375e..103587c3 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -44,6 +44,7 @@ def __init__( report: Optional[str] = "auto", save_dvc_exp: bool = False, dvcyaml: bool = True, + cache_images: bool = False, ): super().__init__() self._prefix = prefix @@ -52,6 +53,7 @@ def __init__( "report": report, "save_dvc_exp": save_dvc_exp, "dvcyaml": dvcyaml, + "cache_images": cache_images, } if dir is not None: self._live_init["dir"] = dir diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index 873a8896..8aa4fbf5 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -144,7 +144,9 @@ def test_lightning_default_dir(tmp_dir): def test_lightning_kwargs(tmp_dir): model = LitXOR() # Handle kwargs passed to Live. - dvclive_logger = DVCLiveLogger(dir="dir", report="md", dvcyaml=False) + dvclive_logger = DVCLiveLogger( + dir="dir", report="md", dvcyaml=False, cache_images=True + ) trainer = Trainer( logger=dvclive_logger, max_epochs=2, @@ -156,6 +158,7 @@ def test_lightning_kwargs(tmp_dir): assert os.path.exists("dir") assert os.path.exists("dir/report.md") assert not os.path.exists("dir/dvc.yaml") + assert dvclive_logger.experiment._cache_images is True def test_lightning_steps(tmp_dir, mocker):