diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index cfb8375e..103587c3 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -44,6 +44,7 @@ def __init__( report: Optional[str] = "auto", save_dvc_exp: bool = False, dvcyaml: bool = True, + cache_images: bool = False, ): super().__init__() self._prefix = prefix @@ -52,6 +53,7 @@ def __init__( "report": report, "save_dvc_exp": save_dvc_exp, "dvcyaml": dvcyaml, + "cache_images": cache_images, } if dir is not None: self._live_init["dir"] = dir diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index 873a8896..8aa4fbf5 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -144,7 +144,9 @@ def test_lightning_default_dir(tmp_dir): def test_lightning_kwargs(tmp_dir): model = LitXOR() # Handle kwargs passed to Live. - dvclive_logger = DVCLiveLogger(dir="dir", report="md", dvcyaml=False) + dvclive_logger = DVCLiveLogger( + dir="dir", report="md", dvcyaml=False, cache_images=True + ) trainer = Trainer( logger=dvclive_logger, max_epochs=2, @@ -156,6 +158,7 @@ def test_lightning_kwargs(tmp_dir): assert os.path.exists("dir") assert os.path.exists("dir/report.md") assert not os.path.exists("dir/dvc.yaml") + assert dvclive_logger.experiment._cache_images is True def test_lightning_steps(tmp_dir, mocker):