From eb9c37b4f159ac44667e5dff1f3fe5211c308b17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Fri, 4 Nov 2022 16:46:00 +0545 Subject: [PATCH] wip: top level params and metrics --- dvc/repo/metrics/show.py | 23 +++++++++++++++++++++++ dvc/repo/params/show.py | 32 ++++++++++++++++++++++++++++++++ dvc/schema.py | 2 ++ 3 files changed, 57 insertions(+) diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index bb209b0c050..1f62935f616 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -33,6 +33,28 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths: return result +def _collect_top_level_metrics(repo): + from dvc.dvcfile import Dvcfile + from dvc.stage import PipelineStage + + files = [] + dvcfiles = { + stage.dvcfile + for stage in repo.index.stages + if isinstance(stage, PipelineStage) + } + dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml"))) + for dvcfile in dvcfiles: + wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path)) + try: + metrics = dvcfile.load().get("metrics", []) + except Exception: # pylint: disable=broad-except + logger.debug("", exc_info=True) + continue + files.extend(repo.dvcfs.path.join(wdir, file) for file in metrics) + return files + + def _collect_metrics(repo, targets, revision, recursive): metrics, fs_paths = collect( repo, @@ -99,6 +121,7 @@ def _read_metrics(repo, metrics, rev, onerror=None): def _gather_metrics(repo, targets, rev, recursive, onerror=None): metrics = _collect_metrics(repo, targets, rev, recursive) + metrics.extend(_collect_top_level_metrics(repo)) return _read_metrics(repo, metrics, rev, onerror=onerror) diff --git a/dvc/repo/params/show.py b/dvc/repo/params/show.py index 684eb4037d2..cf031a6729d 100644 --- a/dvc/repo/params/show.py +++ b/dvc/repo/params/show.py @@ -35,6 +35,27 @@ def _is_params(dep: "Output"): return isinstance(dep, ParamsDependency) +def _collect_top_level_params(repo): + from dvc.dvcfile import Dvcfile + + files = [] + dvcfiles = { + stage.dvcfile + for stage in repo.index.stages + if isinstance(stage, PipelineStage) + } + dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml"))) + for dvcfile in dvcfiles: + wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path)) + try: + params = dvcfile.load().get("params", []) + except Exception: # pylint: disable=broad-except + logger.debug("", exc_info=True) + continue + files.extend(repo.dvcfs.path.join(wdir, file) for file in params) + return files + + def _collect_configs( repo: "Repo", rev, targets=None, deps=False, stages=None ) -> Tuple[List["Output"], List[str]]: @@ -93,7 +114,17 @@ def _read_params( else: fs_paths += [param.fs_path for param in params] + relpath = "" + if repo.root_dir != repo.fs.path.getcwd(): + relpath = repo.fs.path.relpath(repo.root_dir, repo.fs.path.getcwd()) + for fs_path in fs_paths: + rel_param_path = os.path.join(relpath, *repo.fs.path.parts(fs_path)) + if not repo.fs.isfile(fs_path): + if repo.fs.isfile(rel_param_path): + fs_path = rel_param_path + else: + continue from_path = _read_fs_path(repo.fs, fs_path, onerror=onerror) if from_path: name = os.sep.join(repo.fs.path.relparts(fs_path)) @@ -176,6 +207,7 @@ def _gather_params( param_outs, params_fs_paths = _collect_configs( repo, rev, targets=targets, deps=deps, stages=stages ) + params_fs_paths.extend(_collect_top_level_params(repo=repo)) params = _read_params( repo, params=param_outs, diff --git a/dvc/schema.py b/dvc/schema.py index d18ad6af265..cd98e14592e 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -130,6 +130,8 @@ def validator(data): PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]), STAGES: SINGLE_PIPELINE_STAGE_SCHEMA, VARS_KWD: VARS_SCHEMA, + StageParams.PARAM_PARAMS: [str], + StageParams.PARAM_METRICS: [str], } COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)