Skip to content

Commit

Permalink
collect top-level params/plots/metrics together with stages
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Dec 1, 2022
1 parent aae6672 commit eb59c16
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 79 deletions.
38 changes: 37 additions & 1 deletion dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,50 @@ def __init__(
if stages is not None:
self.stages: List["Stage"] = stages
self._collected_targets: Dict[int, List["StageInfo"]] = {}
self._metrics: Dict[str, List[str]] = {}
self._plots: Dict[str, Any] = {}
self._params: Dict[str, List[str]] = {}

@cached_property
def stages(self) -> List["Stage"]:  # pylint: disable=method-hidden
    """Return the stages of this index, collecting them on first access.

    The work is delegated to ``self._collect()``; whatever it returns is the
    value of this property.
    """
    # note that ideally we should be keeping this in a set as it is unique,
    # hashable and has no concept of orderliness on its own. But we depend
    # on this to be somewhat ordered for status/metrics/plots, etc.
    return self._collect()

@cached_property
def _top_metrics(self):
self._collect()
return self._metrics

@cached_property
def _top_plots(self):
self._collect()
return self._plots

@cached_property
def _top_params(self):
self._collect()
return self._params

def _collect(self):
if "stages" in self.__dict__:
return self.stages

onerror = self.repo.stage_collection_error_handler
return self.stage_collector.collect_repo(onerror=onerror)

# pylint: disable=protected-access
(
stages,
metrics,
plots,
params,
) = self.stage_collector._collect_all_from_repo(onerror=onerror)
self.stages = stages
self._metrics = metrics
self._plots = plots
self._params = params
return stages

def __repr__(self) -> str:
from dvc.fs import LocalFileSystem
Expand Down
25 changes: 6 additions & 19 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,12 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths:


def _collect_top_level_metrics(repo):
from dvc.dvcfile import Dvcfile
from dvc.stage import PipelineStage

files = []
dvcfiles = {
stage.dvcfile
for stage in repo.index.stages
if isinstance(stage, PipelineStage)
}
dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml")))
for dvcfile in dvcfiles:
wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path))
try:
metrics = dvcfile.load().get("metrics", [])
except Exception: # pylint: disable=broad-except
logger.debug("", exc_info=True)
continue
files.extend(repo.dvcfs.path.join(wdir, file) for file in metrics)
return files
top_metrics = repo.index._top_metrics # pylint: disable=protected-access
for dvcfile, metrics in top_metrics.items():
wdir = repo.dvcfs.path.relpath(
repo.dvcfs.path.parent(dvcfile), repo.root_dir
)
yield from (repo.dvcfs.path.join(wdir, file) for file in metrics)


def _collect_metrics(repo, targets, revision, recursive):
Expand Down
34 changes: 6 additions & 28 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,12 @@ def _is_params(dep: "Output"):


def _collect_top_level_params(repo):
from dvc.dvcfile import Dvcfile

files = []
dvcfiles = {
stage.dvcfile
for stage in repo.index.stages
if isinstance(stage, PipelineStage)
}
dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml")))
for dvcfile in dvcfiles:
wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path))
try:
params = dvcfile.load().get("params", [])
except Exception: # pylint: disable=broad-except
logger.debug("", exc_info=True)
continue
files.extend(repo.dvcfs.path.join(wdir, file) for file in params)
return files
top_params = repo.index._top_params # pylint: disable=protected-access
for dvcfile, params in top_params.items():
wdir = repo.dvcfs.path.relpath(
repo.dvcfs.path.parent(dvcfile), repo.root_dir
)
yield from (repo.dvcfs.path.join(wdir, file) for file in params)


def _collect_configs(
Expand Down Expand Up @@ -114,17 +102,7 @@ def _read_params(
else:
fs_paths += [param.fs_path for param in params]

relpath = ""
if repo.root_dir != repo.fs.path.getcwd():
relpath = repo.fs.path.relpath(repo.root_dir, repo.fs.path.getcwd())

for fs_path in fs_paths:
rel_param_path = os.path.join(relpath, *repo.fs.path.parts(fs_path))
if not repo.fs.isfile(fs_path):
if repo.fs.isfile(rel_param_path):
fs_path = rel_param_path
else:
continue
from_path = _read_fs_path(repo.fs, fs_path, onerror=onerror)
if from_path:
name = os.sep.join(repo.fs.path.relparts(fs_path))
Expand Down
55 changes: 26 additions & 29 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,36 +439,33 @@ def _resolve_definitions(


def _collect_pipeline_files(repo, targets: List[str], props, onerror=None):
    """Resolve the plot definitions listed in the index, per declaring file.

    ``repo.index._top_plots`` maps the path of a declaring file to its plots
    section. A section given as a list is first normalised to a dict: a bare
    string entry becomes ``{entry: None}`` and a mapping entry contributes
    its first key/value pair. The normalised definitions are passed to
    ``_resolve_definitions`` and the outcome is merged into the result under
    the declaring file's path as returned by ``_relpath``.

    The superseded loop that reloaded each ``stage.dvcfile`` from disk is
    removed; leaving it next to the index-based loop processed and merged
    the definitions twice.

    Returns:
        A dict keyed by dvcfile path holding the resolved definitions.
    """
    result: Dict[str, Dict] = {}
    top_plots = repo.index._top_plots  # pylint: disable=protected-access
    for dvcfile, plots_def in top_plots.items():
        dvcfile_path = _relpath(repo.dvcfs, dvcfile)
        dvcfile_defs_dict: Dict[str, Union[Dict, None]] = {}
        if isinstance(plots_def, list):
            for elem in plots_def:
                if isinstance(elem, str):
                    dvcfile_defs_dict[elem] = None
                else:
                    # only the first pair of a mapping entry is used
                    k, v = list(elem.items())[0]
                    dvcfile_defs_dict[k] = v
        else:
            dvcfile_defs_dict = plots_def

        resolved = _resolve_definitions(
            repo.dvcfs,
            targets,
            props,
            dvcfile_path,
            dvcfile_defs_dict,
            onerror=onerror,
        )
        dpath.util.merge(
            result,
            {dvcfile_path: resolved},
        )
    return result


Expand Down
101 changes: 101 additions & 0 deletions dvc/repo/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,104 @@ def is_out_or_ignored(root, directory):

def collect_repo(self, onerror: Callable[[str, Exception], None] = None):
    """Return everything produced by ``self._collect_repo(onerror)`` as a list.

    Args:
        onerror (optional): forwarded unchanged to ``_collect_repo``.
    """
    return list(self._collect_repo(onerror))

def _load_file(self, path):
    """Load one dvcfile and return its stages and top-level sections.

    Args:
        path: location of the file; it is passed through
            ``self._get_filepath`` before the ``Dvcfile`` is built.

    Returns:
        A 4-tuple ``(stages, metrics_data, plots_data, params_data)`` where
        ``stages`` is a list and the other three are read off the loader
        object exposed as ``dvcfile.stages``.
    """
    from dvc.dvcfile import Dvcfile
    from dvc.stage.loader import SingleStageLoader, StageLoader

    path = self._get_filepath(path)
    dvcfile = Dvcfile(self.repo, path)
    # `dvcfile.stages` is not cached
    stages = dvcfile.stages  # type: ignore

    if isinstance(stages, SingleStageLoader):
        # a single-stage loader is indexed with `None`
        stages_ = [stages[None]]
    else:
        assert isinstance(stages, StageLoader)
        keys = self._get_keys(stages)
        stages_ = [stages[key] for key in keys]

    return (
        stages_,
        stages.metrics_data,
        stages.plots_data,
        stages.params_data,
    )

def _collect_all_from_repo(
    self, onerror: Callable[[str, Exception], None] = None
):
    """Collect stages and top-level metrics/plots/params from the repo.

    Walks the repo from ``self.repo.root_dir`` and loads every valid,
    non-ignored dvcfile with ``self._load_file``.

    Args:
        onerror (optional): callable that will be called with two args:
            the filepath whose collection failed and the exc instance.
            It can report the error to continue with the collection
            (and, skip failed ones), or raise the exception to abort
            the collection.

    Returns:
        A 4-tuple ``(stages, metrics, plots, params)``. ``stages`` is a flat
        list; the other three are dicts keyed by the path of the file that
        declared them, and only hold files whose section is non-empty.

    Raises:
        DvcException: re-raised from loading a file when ``onerror`` is not
            given.
    """
    from dvc.dvcfile import is_valid_filename
    from dvc.fs import LocalFileSystem

    scm = self.repo.scm
    sep = self.fs.sep
    outs: Set[str] = set()

    is_local_fs = isinstance(self.fs, LocalFileSystem)

    def is_ignored(path):
        # apply only for the local fs
        return is_local_fs and scm.is_ignored(path)

    def is_dvcfile_and_not_ignored(root, file):
        return is_valid_filename(file) and not is_ignored(
            f"{root}{sep}{file}"
        )

    def is_out_or_ignored(root, directory):
        dir_path = f"{root}{sep}{directory}"
        # trailing slash needed to check if a directory is gitignored
        return dir_path in outs or is_ignored(f"{dir_path}{sep}")

    walk_iter = self.repo.dvcignore.walk(self.fs, self.repo.root_dir)
    if logger.isEnabledFor(logging.TRACE):  # type: ignore[attr-defined]
        walk_iter = log_walk(walk_iter)

    stages = []
    metrics = {}
    plots = {}
    params = {}

    for root, dirs, files in walk_iter:
        dvcfile_filter = partial(is_dvcfile_and_not_ignored, root)
        for file in filter(dvcfile_filter, files):
            file_path = self.fs.path.join(root, file)
            try:
                (
                    new_stages,
                    new_metrics,
                    new_plots,
                    new_params,
                ) = self._load_file(file_path)
            except DvcException as exc:
                if onerror:
                    onerror(relpath(file_path), exc)
                    continue
                raise

            stages.extend(new_stages)
            if new_metrics:
                metrics[file_path] = new_metrics
            if new_plots:
                plots[file_path] = new_plots
            if new_params:
                params[file_path] = new_params

            # remember local outputs so they can be filtered out of `dirs`
            outs.update(
                out.fspath
                for stage in new_stages
                for out in stage.outs
                if out.protocol == "local"
            )
        # slice assignment mutates the list handed out by the walker;
        # presumably this prunes the walk as with os.walk -- confirm
        # against dvcignore.walk
        dirs[:] = [d for d in dirs if not is_out_or_ignored(root, d)]
    return stages, metrics, plots, params
6 changes: 6 additions & 0 deletions dvc/stage/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self, dvcfile, data, lockfile_data=None):
self.dvcfile = dvcfile
self.data = data or {}
self.stages_data = self.data.get("stages", {})
self.metrics_data = self.data.get("metrics", [])
self.params_data = self.data.get("params", [])
self.plots_data = self.data.get("plots", {})
self.repo = self.dvcfile.repo

lockfile_data = lockfile_data or {}
Expand Down Expand Up @@ -171,6 +174,9 @@ def __init__(self, dvcfile, stage_data, stage_text=None):
self.dvcfile = dvcfile
self.stage_data = stage_data or {}
self.stage_text = stage_text
self.metrics_data = []
self.params_data = []
self.plots_data = {}

def __getitem__(self, item):
if item:
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/command/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_status_empty(dvc, mocker, capsys):

cmd = cli_args.func(cli_args)

spy = mocker.spy(cmd.repo.stage, "_collect_repo")
spy = mocker.spy(cmd.repo.stage, "_collect_all_from_repo")

assert cmd.run() == 0

Expand All @@ -108,7 +108,10 @@ def test_status_up_to_date(dvc, mocker, capsys, cloud_opts, expected_message):
mocker.patch.dict(cmd.repo.config, {"core": {"remote": "default"}})
mocker.patch.object(cmd.repo, "status", autospec=True, return_value={})
mocker.patch.object(
cmd.repo.stage, "_collect_repo", return_value=[object()], autospec=True
cmd.repo.stage,
"_collect_all_from_repo",
return_value=[[object()], {}, {}, {}],
autospec=True,
)

assert cmd.run() == 0
Expand Down

0 comments on commit eb59c16

Please sign in to comment.