From e1ffaaa1569b49349206764b5cc3d51fe009ac01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 3 Jan 2023 17:18:25 +0545 Subject: [PATCH] internal: hoist resolver/top-level metrics/plots/params data into file level --- dvc/dvcfile.py | 53 +++++++++++++++++-- dvc/repo/stage.py | 22 ++------ dvc/stage/__init__.py | 1 + dvc/stage/loader.py | 21 +++----- tests/func/test_stage.py | 2 + tests/unit/stage/test_loader_pipeline_file.py | 16 +++--- 6 files changed, 73 insertions(+), 42 deletions(-) diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index c409aebb26..e7d5055e07 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -1,7 +1,18 @@ import contextlib import logging import os -from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Tuple, + TypeVar, + Union, +) + +from funcy import cached_property from dvc.exceptions import DvcException from dvc.parsing.versions import LOCKFILE_VERSION, SCHEMA_KWD @@ -180,6 +191,10 @@ class SingleStageFile(FileMixin): from dvc.schema import COMPILED_SINGLE_STAGE_SCHEMA as SCHEMA from dvc.stage.loader import SingleStageLoader as LOADER + metrics: List[str] = [] + plots: Dict[str, Any] = {} + params: List[str] = [] + @property def stage(self): data, raw = self._load() @@ -223,6 +238,11 @@ class PipelineFile(FileMixin): def _lockfile(self): return Lockfile(self.repo, os.path.splitext(self.path)[0] + ".lock") + def _reset(self): + self.__dict__.pop("contents", None) + self.__dict__.pop("lockfile_contents", None) + self.__dict__.pop("resolver", None) + def dump( self, stage, update_pipeline=True, update_lock=True, **kwargs ): # pylint: disable=arguments-differ @@ -278,11 +298,36 @@ def stage(self): "PipelineFile has multiple stages. Please specify it's name." ) + @cached_property + def contents(self): + return self._load()[0] + + @cached_property + def lockfile_contents(self): + return self._lockfile.load() + + @cached_property + def resolver(self): + from .parsing import DataResolver + + wdir = self.repo.fs.path.parent(self.path) + return DataResolver(self.repo, wdir, self.contents) + @property def stages(self): - data, _ = self._load() - lockfile_data = self._lockfile.load() - return self.LOADER(self, data, lockfile_data) + return self.LOADER(self, self.contents, self.lockfile_contents) + + @property + def metrics(self): + return self.contents.get("metrics", []) + + @property + def plots(self): + return self.contents.get("plots", {}) + + @property + def params(self): + return self.contents.get("params", []) def remove(self, force=False): if not force: diff --git a/dvc/repo/stage.py b/dvc/repo/stage.py index b1f3008cb3..74c71e1cd9 100644 --- a/dvc/repo/stage.py +++ b/dvc/repo/stage.py @@ -517,7 +517,6 @@ def _load_file(self, path): path = self._get_filepath(path) dvcfile = Dvcfile(self.repo, path) - # `dvcfile.stages` is not cached stages = dvcfile.stages # type: ignore if isinstance(stages, SingleStageLoader): @@ -526,13 +525,7 @@ def _load_file(self, path): assert isinstance(stages, StageLoader) keys = self._get_keys(stages) stages_ = [stages[key] for key in keys] - - return ( - stages_, - stages.metrics_data, - stages.plots_data, - stages.params_data, - ) + return stages_, dvcfile def _collect_all_from_repo( self, onerror: Callable[[str, Exception], None] = None @@ -583,12 +576,7 @@ def is_out_or_ignored(root, directory): for file in filter(dvcfile_filter, files): file_path = self.fs.path.join(root, file) try: - ( - new_stages, - new_metrics, - new_plots, - new_params, - ) = self._load_file(file_path) + new_stages, dvcfile = self._load_file(file_path) except DvcException as exc: if onerror: onerror(relpath(file_path), exc) @@ -596,11 +584,11 @@ def is_out_or_ignored(root, directory): raise stages.extend(new_stages) - if new_metrics: + if new_metrics := dvcfile.metrics: metrics[file_path] = new_metrics - if new_plots: + if new_plots := dvcfile.plots: plots[file_path] = new_plots - if new_params: + if new_params := dvcfile.params: params[file_path] = new_params outs.update( diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py index 6bdb5446c6..c75210bace 100644 --- a/dvc/stage/__init__.py +++ b/dvc/stage/__init__.py @@ -775,6 +775,7 @@ def addressing(self): return f"{super().addressing}:{self.name}" def reload(self): + self.dvcfile._reset() # pylint: disable=protected-access return self.dvcfile.stages[self.name] def _status_stage(self, ret): diff --git a/dvc/stage/loader.py b/dvc/stage/loader.py index 93092e3894..9f4412a05d 100644 --- a/dvc/stage/loader.py +++ b/dvc/stage/loader.py @@ -6,7 +6,7 @@ from funcy import cached_property, get_in, lcat, once, project from dvc import dependency, output -from dvc.parsing import FOREACH_KWD, JOIN, DataResolver, EntryNotFound +from dvc.parsing import FOREACH_KWD, JOIN, EntryNotFound from dvc.parsing.versions import LOCKFILE_VERSION from dvc_data.hashfile.hash_info import HashInfo from dvc_data.hashfile.meta import Meta @@ -20,13 +20,16 @@ class StageLoader(Mapping): - def __init__(self, dvcfile, data, lockfile_data=None): + def __init__( + self, + dvcfile, + data, + lockfile_data=None, + ): self.dvcfile = dvcfile + self.resolver = self.dvcfile.resolver self.data = data or {} self.stages_data = self.data.get("stages", {}) - self.metrics_data = self.data.get("metrics", []) - self.params_data = self.data.get("params", []) - self.plots_data = self.data.get("plots", {}) self.repo = self.dvcfile.repo lockfile_data = lockfile_data or {} @@ -36,11 +39,6 @@ def __init__(self, dvcfile, data, lockfile_data=None): else: self._lockfile_data = lockfile_data.get("stages", {}) - @cached_property - def resolver(self): - wdir = self.repo.fs.path.parent(self.dvcfile.path) - return DataResolver(self.repo, wdir, self.data) - @cached_property def lockfile_data(self): if not self._lockfile_data: @@ -174,9 +172,6 @@ def __init__(self, dvcfile, stage_data, stage_text=None): self.dvcfile = dvcfile self.stage_data = stage_data or {} self.stage_text = stage_text - self.metrics_data = [] - self.params_data = [] - self.plots_data = {} def __getitem__(self, item): if item: diff --git a/tests/func/test_stage.py b/tests/func/test_stage.py index ba0a0ec609..6596b15d58 100644 --- a/tests/func/test_stage.py +++ b/tests/func/test_stage.py @@ -317,6 +317,8 @@ def test_stage_remove_pipeline_stage(tmp_dir, dvc, run_copy): with dvc.lock: stage.remove() + + dvc_file._reset() assert stage.name not in dvc_file.stages assert "copy-bar-foobar" in dvc_file.stages diff --git a/tests/unit/stage/test_loader_pipeline_file.py b/tests/unit/stage/test_loader_pipeline_file.py index 5ef37201bc..8d77059726 100644 --- a/tests/unit/stage/test_loader_pipeline_file.py +++ b/tests/unit/stage/test_loader_pipeline_file.py @@ -247,11 +247,11 @@ def test_load_stage_wdir_and_path_correctly(dvc, stage_data, lock_data): def test_load_stage_mapping(dvc, stage_data, lock_data): dvcfile = Dvcfile(dvc, PIPELINE_FILE) - loader = StageLoader( - dvcfile, {"stages": {"stage": stage_data}}, {"stage": lock_data} - ) - assert len(loader) == 1 - assert "stage" in loader - assert "stage1" not in loader - assert loader.keys() == {"stage"} - assert isinstance(loader["stage"], PipelineStage) + dvcfile.contents = {"stages": {"stage": stage_data}} + dvcfile.lockfile_contents = {"stage": lock_data} + + assert len(dvcfile.stages) == 1 + assert "stage" in dvcfile.stages + assert "stage1" not in dvcfile.stages + assert dvcfile.stages.keys() == {"stage"} + assert isinstance(dvcfile.stages["stage"], PipelineStage)