Skip to content

Commit

Permalink
wip: top level params and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Nov 4, 2022
1 parent d7152e3 commit eb9c37b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
23 changes: 23 additions & 0 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def _to_fs_paths(metrics: List[Output]) -> StrPaths:
return result


def _collect_top_level_metrics(repo):
from dvc.dvcfile import Dvcfile
from dvc.stage import PipelineStage

files = []
dvcfiles = {
stage.dvcfile
for stage in repo.index.stages
if isinstance(stage, PipelineStage)
}
dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml")))
for dvcfile in dvcfiles:
wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path))
try:
metrics = dvcfile.load().get("metrics", [])
except Exception: # pylint: disable=broad-except
logger.debug("", exc_info=True)
continue
files.extend(repo.dvcfs.path.join(wdir, file) for file in metrics)
return files


def _collect_metrics(repo, targets, revision, recursive):
metrics, fs_paths = collect(
repo,
Expand Down Expand Up @@ -99,6 +121,7 @@ def _read_metrics(repo, metrics, rev, onerror=None):

def _gather_metrics(repo, targets, rev, recursive, onerror=None):
metrics = _collect_metrics(repo, targets, rev, recursive)
metrics.extend(_collect_top_level_metrics(repo))
return _read_metrics(repo, metrics, rev, onerror=onerror)


Expand Down
32 changes: 32 additions & 0 deletions dvc/repo/params/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ def _is_params(dep: "Output"):
return isinstance(dep, ParamsDependency)


def _collect_top_level_params(repo):
from dvc.dvcfile import Dvcfile

files = []
dvcfiles = {
stage.dvcfile
for stage in repo.index.stages
if isinstance(stage, PipelineStage)
}
dvcfiles.add(Dvcfile(repo, repo.dvcfs.from_os_path("dvc.yaml")))
for dvcfile in dvcfiles:
wdir = repo.dvcfs.path.parent(repo.dvcfs.from_os_path(dvcfile.path))
try:
params = dvcfile.load().get("params", [])
except Exception: # pylint: disable=broad-except
logger.debug("", exc_info=True)
continue
files.extend(repo.dvcfs.path.join(wdir, file) for file in params)
return files


def _collect_configs(
repo: "Repo", rev, targets=None, deps=False, stages=None
) -> Tuple[List["Output"], List[str]]:
Expand Down Expand Up @@ -93,7 +114,17 @@ def _read_params(
else:
fs_paths += [param.fs_path for param in params]

relpath = ""
if repo.root_dir != repo.fs.path.getcwd():
relpath = repo.fs.path.relpath(repo.root_dir, repo.fs.path.getcwd())

for fs_path in fs_paths:
rel_param_path = os.path.join(relpath, *repo.fs.path.parts(fs_path))
if not repo.fs.isfile(fs_path):
if repo.fs.isfile(rel_param_path):
fs_path = rel_param_path
else:
continue
from_path = _read_fs_path(repo.fs, fs_path, onerror=onerror)
if from_path:
name = os.sep.join(repo.fs.path.relparts(fs_path))
Expand Down Expand Up @@ -176,6 +207,7 @@ def _gather_params(
param_outs, params_fs_paths = _collect_configs(
repo, rev, targets=targets, deps=deps, stages=stages
)
params_fs_paths.extend(_collect_top_level_params(repo=repo))
params = _read_params(
repo,
params=param_outs,
Expand Down
2 changes: 2 additions & 0 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def validator(data):
PLOTS: Any(SINGLE_PLOT_SCHEMA, [Any(str, SINGLE_PLOT_SCHEMA)]),
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
VARS_KWD: VARS_SCHEMA,
StageParams.PARAM_PARAMS: [str],
StageParams.PARAM_METRICS: [str],
}

COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)
Expand Down

0 comments on commit eb9c37b

Please sign in to comment.